Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add prediction utilities #560

Closed
wants to merge 5 commits into from
Closed

Add prediction utilities #560

wants to merge 5 commits into from

Conversation

ashnair1
Copy link
Collaborator

@ashnair1 ashnair1 commented Jun 7, 2022

The main aim of this PR is to add utilities that will allow users to use a torchgeo segmentation model to create predictions on GeoTIFFs images of arbitrary size.

Example using a model trained on the InriaAerialImageLabelling dataset shown below. Images in this dataset are of size 5000 x 5000 and the model was trained on 512 x 512 crops.

image

@github-actions github-actions bot added datamodules PyTorch Lightning datamodules datasets Geospatial or benchmark datasets labels Jun 7, 2022
@adamjstewart
Copy link
Collaborator

@ashnair1 great image, do you mind if we use this in a blog to showcase the cool things TorchGeo can do?

SemanticSegmentationTask,
)

TASK_TO_MODULES_MAPPING: Dict[
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now that this is in two different scripts, we should move it somewhere where both scripts can find it to avoid code duplication. How about a torchgeo/common.py file? I don't want to put it in torchgeo/__init__.py because this will be sourced on every import.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Like the idea of torchgeo/common.py. I can add that in a follow up PR.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1 to this

predict.py Outdated


def write_mask(mask: torch.Tensor, output_dir: str, input_filename: str) -> None:
"""Write mask to specified output directory."""
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Docstring needs descriptions for all inputs, same for other functions

Comment on lines +80 to +99
conf = OmegaConf.load(os.path.join(config_dir, "experiment_config.yaml"))
ckpt = os.path.join(config_dir, "last.ckpt")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should these filenames be parameters?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The ckpt filename sure. Think the config filename ("experiment_config.yaml") is hard coded.

help="Path to output_directory to save predicted mask geotiffs",
)

parser.add_argument("--device", type=str, default="cuda", choices=["cuda", "cpu"])
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this something we should let PyTorch Lightning handle?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps. LightningModules shouldn't need this but I'll need to verify. Will get back to this.

Copy link
Collaborator Author

@ashnair1 ashnair1 Jun 20, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

While LightningModules are aware of which device they're on, the models (UNet, DeepLab etc) are not. Since we can't forward device info, this will be required.


if __name__ == "__main__":
# Taken from https://github.com/pangeo-data/cog-best-practices
_rasterio_best_practices = {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This can also be moved to torchgeo/common.py

)


class PredictDataset(Dataset[Any]):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you clarify why this is necessary?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PredictDataset represents a dataset of images in a folder that you want to run inference on. When you run predict.py you need to pass a predict_on flag. This can be either a test split of a dataset or a folder of images that you want to run inference on. PredictDataset handles the latter.

It accepts a bands argument since we can't predict the number of bands the geotiffs will have and one transform (patch_sample) that it will always do since we can't predict the size of the geotiff.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, I'm still not understanding. Why not use the normal data loader for that dataset (a subclass of either VisionDataset or GeoDataset)? RasterDataset in combination with GridGeoSampler (or just a normal VisionDataset) is specifically designed to handle "a folder of images you want to run inference on". It seems like this won't work for all datasets since it only handles flat directories of images that can be opened with rasterio and don't need to be chipped.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not use the normal data loader for that dataset (a subclass of either VisionDataset or GeoDataset)?

Because the dataloader of that dataset has custom logic for parsing filenames and subdirectories that is not relevant when you want to run inference on a directory of images. You could add code to the _load_files method of each dataset that skips the parsing logic if it is in "predict" mode and I could see a case being made for that.

I just found having a separate class (PredictDataset) that represents a folder of images to be easier. Just need to instantiate it in the corresponding datamodule as seen in this PR.

It seems like this won't work for all datasets since it only handles flat directories of images that can be opened with rasterio and don't need to be chipped.

That's the exact use case being targeted. You have a bunch of images of arbitrary size in a folder and you want to produce inference results.

These images will of course need to be chipped to the size the model expects which is why the PredictDataset contains a single transform that it will always apply - patch_sample which extracts tensor patches according to the patch_size the model was trained on.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because the dataloader of that dataset has custom logic for parsing filenames and subdirectories that is not relevant when you want to run inference on a directory of images.

But this doesn't cause any issues, right? Yes, you could use a simpler class for parsing a directory of images, but we already have datasets for everything. If there's no reason why those existing datasets don't work then I don't see why we need a new dataset.

Copy link
Collaborator Author

@ashnair1 ashnair1 Jul 4, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It sounds like you want to train a model on Inria and then make predictions on a completely different dataset with similar images

Precisely. For example, one of the locations within the inria dataset is Austin. So if someone gives me some aerial imagery with similar resolution (30cm) over Dallas, Houston or even Austin, I want to be able to use this model to run inference.

That's fine, but maybe a name other than PredictDataset would be better.

Maybe PredictionFolder? I'm open to suggestions.

I was under the mistaken impression that PredictDataset would be required for all datasets you want to use in predict.py.

Ah I see. PredictDataset was meant to be a standalone class to be used when you want to run inference on a folder of images of arbitrary size and variable number of bands. You would inject it into the datamodule so that it can be used with predict.py as seen in this PR. It's not connected to any dataset.

We already have RasterDataset and VectorDataset subclasses of GeoDataset that recursively locate files that can be opened with rasterio and fiona. Maybe we need a new ImageDataset subclass of VisionDataset that recursively locates files that can be opened with PIL?

Do we need PIL? I think we can stick with rasterio but not use RasterDataset. Geospatial info is not necessary for the prediction process but we don't want the process to fail if the image lacks geospatial info. Apart from throwing a NotGeoreferencedWarning, rasterio can read and write any image with/without geospatial info.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, didn't even notice that you were using rasterio. In that case, this class definitely isn't needed. You should definitely use RasterDataset for this specific dataset. You might not need geospatial info if you only want to predict on one file, but as soon as you want to predict on multiple files, you need geospatial info to stitch together the predictions.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the file in question does not have geospatial info then RasterDataset can't be used. Also the patching and merging/stitching are done by kornia's ExtractTensorPatches and CombineTensorPatches.

Just so we're on the same page this is how I visualise the workflow:

The main inputs to the predict.py script are a directory of images (inference folder) and the trained model. The directory of images contain images with/without georeference and with different sizes.

The script first runs over each image, patches/chips them to the appropriate size for the model, runs model inference and merges them back together.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This works fine if you want to convert a directory of images to a directory of predictions, but what if you want to predict a single static map of an entire country by stitching together those predictions? I think that's where predict.py would really shine.

I still think the best approach is to use RasterDataset for geospatial data and some kind of subclass of VisionDataset for non-geospatial data. Also, these datasets could be used at training time for some kind of self-supervised learning, so PredictDataset wouldn't be the best name.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems we're coming at this from two different angles.

From my POV, predict.py was only intended to convert a directory of images to a directory of predictions. The motivation was simple: I have a model and I have images from certain areas of interest that I want to run inference on. I might want to merge them together but I could just use gdal_merge.py for that.

The images could be spatially distributed i.e. AOI1.tif, AOI2.tif etc or temporally distributed i.e. AOI_2020.tif, AOI_2019.tif, AOI_2018.tif.

Another place where this would be useful is when you want to upload predictions to an evaluation server. In Inria, the images are of size 5000 x 5000. For training I used 512 x 512 crops. But if I want to upload the predictions to an evaluation server they would expect predictions to be the same size as input i.e. 5000 x 5000.

Predicting a single static map or using it for self supervised learning is well outside my original scope.

In order to properly address this, there needs to some consensus on what prediction scenarios torchgeo is interested in supporting:

  1. Will it support images with and without georeference?
  2. Will it support running on a directory of images and producing a directory of predictions?
  3. Will it provide an option to merge the directory of predictions if they have geo-reference into a single static map?

)
patch_shape = cast(Tuple[int, int], tensor_to_int(batch["patch_shape"]))
padding = cast(Tuple[int, int], tensor_to_int(batch["padding"]))
patch_combine = CombineTensorPatches(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you tell me more about how this function works? What happens if your patches have overlap? Is this sufficient to close #30 or do we need something more powerful/generic?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you tell me more about how this function works?

I've written a tutorial describing how extract and combine patches work here

What happens if your patches have overlap?

Currently you can extract patches with overlap (via the stride parameter) but you can't merge them together. This is because CombineTensorPatches currently only supports stride=window_size as seen here

Is this sufficient to close #30 or do we need something more powerful/generic?

Based on my understanding of the scope of #30, I would say no. If the goal is to just enable users to extract patches and stitch them together, kornia's ExtractTensorPatches and CombineTensorPatches are sufficient. Once CombineTensorPatches supports stride!=window_size, we will be able to handle patches with overlap.

But for the alternate stitching techniques (like label averaging) mentioned in the paper referenced in #30, we might need something more powerful as CombineTensorPatches doesn't support this.

@adamjstewart adamjstewart added this to the 0.3.0 milestone Jun 12, 2022
@ashnair1
Copy link
Collaborator Author

@ashnair1 great image, do you mind if we use this in a blog to showcase the cool things TorchGeo can do?

Not at all 👍

But I just want to mention that the image on the left is an image from the inria dataset (I think bellingham1.tif) so it's not my image 😃

@adamjstewart
Copy link
Collaborator

Which model did you use to produce those masks in the image?

@ashnair1
Copy link
Collaborator Author

Which model did you use to produce those masks in the image?

U-Net with resnet18 backbone. Think I used the exact same config as the one in conf/inria.yaml

@ashnair1 ashnair1 force-pushed the predict branch 2 times, most recently from 647ef0e to 4e7e47f Compare June 20, 2022 08:07
@adamjstewart adamjstewart mentioned this pull request Jun 24, 2022
@github-actions github-actions bot added the testing Continuous integration testing label Jun 29, 2022
@ashnair1 ashnair1 marked this pull request as ready for review June 29, 2022 10:08
@calebrob6
Copy link
Member

I'm going to pull the thread that is going on above down here so it's easier to follow:

Part 1

I think most of the discussion is nicely summarized by the three questions @ashnair1 is bringing up:

In order to properly address this, there needs to some consensus on what prediction scenarios torchgeo is interested in supporting:

  1. Will it support images with and without georeference?
  2. Will it support running on a directory of images and producing a directory of predictions?
  3. Will it provide an option to merge the directory of predictions if they have geo-reference into a single static map?

I think (point-by-point):

  1. Yes, anything that rasterio (libgdal) can read should be supported.
  2. Yes, why not, this is a common enough use case I would imagine
  3. No, this is massively out of scope IMHO (@ashnair1 points out that you could run on a directory then merge with gdal_merge if you realllllly need to do this)

I think we can do this without adding a PredictionDataset (or similar) class pretty easily, but let me double check by trying to implement this and get back to you! Of note, I think whatever we come up with will need to assume the existence of a preprocess(sample) method in all of our DataModules. Different SemanticSegmentationTask instances will use different DataModules and will thus expect different pre-processing methods to be applied to their inputs.

Part 2

Separate to the above questions, there is the question of whether we should allow treating the test DataLoader of one of our existing DataModules as a "target" for this script. I think @ashnair1's specific use case will be solved by a script that implements points 1 and 2 above, so this might be an irrelevant question for the time being.

@calebrob6
Copy link
Member

To sketch out what I'm doing:

Load the task / model as in the current implementation
os.walk the predict_on dir to find every file that rasterio can load
for each file:
  create a RasterDataset using only that file, use the preprocess function from our input task's datamodule as a transform
  use GridSampler + a Dataloader to extract patches from the file
  put the predicted patches back together
  write result as GeoTIFF considering the .profile of the input file, assume outputs from the model are uint8 class values

@calebrob6
Copy link
Member

calebrob6 commented Jul 6, 2022

Here is a working implementation https://gist.github.com/calebrob6/17147f6a256486f08c0d912f40114fbf where I've taken a landcover.ai trained model and run it on an Inria test scene. Bonus: the predictions make some sense! I've annotated in TODOs where it needs some more flushing out.

Some example screenshots from QGIS (where I'm visualizing the imagery and outputs):

image

This is what the predictions look like with PATCH_SIZE = 256 where I'm not throwing away predicted pixels on the edges of each patch. Here you can see checkerboarding artifacts:

image

This is what the predictions look like with PATCH_SIZE = 512 where I'm throwing away 96 pixels along the edge of each predicted patch:

image

(naively this will leave a strip of 96 empty pixels along the edges of the larger tile -- another TODO would be not throwing away those predictions)

@ashnair1 @adamjstewart, let me know what you think!

@adamjstewart
Copy link
Collaborator

  1. Will it support images with and without georeference?
  2. Will it support running on a directory of images and producing a directory of predictions?
  3. Will it provide an option to merge the directory of predictions if they have geo-reference into a single static map?

I think it's important to distinguish needs vs. wants vs. dreams. Needs are things that should probably be supported before this PR is merged. Wants and dreams are for a future adventurous coder to implement, but I'm happy to discuss where I envision this script going in the future.

Needs: I think a basic requirement of predict.py would be that it can handle images with and without geographic metadata. Many of our builtin datasets lack geographic metadata. Since geographic metadata isn't actually being used in this script as far as I can tell, we shouldn't require images to have it.

Wants: I think the most common use case for this script would be for a user to point it to a directory of image files and make predictions on those files. But I think it would also be cool to support builtin Datasets and DataModules. Maybe this overlaps with evaluate.py though.

Dreams: I would love to download imagery for the entire Earth, stuff it in a directory, run predict.py on it, and get back a single prediction map for the entire Earth (or a country or continent or whatever).

Another important thing to think about is how all of these scripts play with each other. The way I see it, which may or may not be the way things currently are:

  • train.py: train a model, either supervised or self-supervised
  • evaluate.py: evaluate the performance of a model trained by train.py
  • predict.py: make predictions on a set of images using a model trained by train.py

Are we all on the same page with this, or is there disagreement? I also want the configs in conf.py to be close to state-of-the-art performance, not just random configs. Some day @calebrob6 and I want a leaderboard of sorts, similar to https://pytorch.org/vision/stable/models.html#table-of-all-available-classification-weights.

All of the *.py files in the root of the project are currently not included in the code uploaded to PyPI. I've given some thought to adding them, but we still need to decide on an API and actually implement entrypoints (#228). Until then, we have the privilege of not caring about backwards compatibility or a stable API. Since this isn't included in the PyPI release, there's no rush to stuff this PR into 0.3.0, but I'm also fine with merging it (see below) and iterating on it later.

I still don't think PredictDataset is necessary for predict.py. It would probably be better to use RasterDataset or create a new ImageDataset for files that lack geographic metadata. @calebrob6 has a suggestion above that should work for many datasets, but in the long run we really need to finish #409 to make things faster with less distortion. In the meantime, if @calebrob6's suggestion doesn't work, maybe we can move PredictDataset to predict.py and refactor later? Just want to avoid adding an entirely new base class and removing or renaming it later.

@calebrob6
Copy link
Member

Are we all on the same page with this, or is there disagreement?

Yep, I agree with the needs/wants/dreams (although the "dream" seems way out of scope for us to implement in torchgeo). The code in the gist above satisfies the functionality of the "needs" and can run on a directory of images (part of the "wants").

I also want the configs in conf.py to be close to state-of-the-art performance...

Agree, but this doesn't seem very relevant to this PR.

I still don't think PredictDataset is necessary for predict.py...

My implementation doesn't use a PredictDataset concept. I search over a directory for everything that rasterio can read (the same way we do in RasterDataset) then create a single RasterDataset instance for each resulting file. E.g. you can point my script to the Inria test image directory and get predictions for each file. It actually doesn't make sense to create a single RasterDataset instance for the entire Inria test set as there are many 5000x5000 tiles that come from different places around the world -- a GridGeoSampler here (even solving the #409 no-reprojection issue) would return patches for you to run inference on but there wouldn't be any way to stitch the predictions back together into a bunch of individual 5000x5000 tiles.

@ashnair1
Copy link
Collaborator Author

ashnair1 commented Jul 6, 2022

Are we all on the same page with this, or is there disagreement?

Yup. I’ve always seen the scripts that way.

Thanks @calebrob6. Your solution looks like it might solve my current issue but I need to test it out first. The only issue I see is how to handle non-georeferenced images but I'm guessing we can switch to kornia's functions (ExtractTensorPatches and CombineTensorPatches) for that. Also need to address the TODOs.

Since this isn't included in the PyPI release, there's no rush to stuff this PR into 0.3.0, but I'm also fine with merging it (see below) and iterating on it later.

Agreed. No need to tie this to 0.3.0. I’ll take a look at @calebrob6’s solution and see if it works for my use cases.

I’ll be travelling this week so I’ll get back to this when I’m back.

@calebrob6
Copy link
Member

Good point about the non-georeferenced images (particularly, images that rasterio can't read).

If you don't mind, I can push a cleaned up version of what I have (i.e. with some of the TODOs fixed 😄) on this PR.

@calebrob6
Copy link
Member

The only issue I see is how to handle non-georeferenced images

Just did some testing and it looks like rasterio can handle PNGs and JPGs no problem so I think we can get away with using RasterDataset for everything!

@adamjstewart
Copy link
Collaborator

But if it can't read geospatial metadata I'm not sure how it will add them to the rtree index.

@calebrob6
Copy link
Member

It doesn't really need an rtree index as there will only be one file at a time. We really just want to re-use the logic from GridGeoSampler.

@calebrob6
Copy link
Member

calebrob6 commented Jul 6, 2022

The previous 4 5 commits clean up how our DataModules pre-process image inputs. Now, each DataModule has a preprocess(self, sample) method that performs the image preprocessing and doesn't break if there isn't a "mask" or "label" key in the sample. This is important as preprocess.py needs to be able to apply the same image pre-processing that a trained model checkpoint used while training, however the new data that the model is being run over will definitely not have any "mask" or "label" parts.

I also took this opportunity to clean up a few things throughout the DataModules -- I'll go through and explain these changes inline with comments.

conf/etci2021.yaml Outdated Show resolved Hide resolved
torchgeo/datamodules/chesapeake.py Outdated Show resolved Hide resolved
torchgeo/datamodules/chesapeake.py Outdated Show resolved Hide resolved
torchgeo/datamodules/chesapeake.py Outdated Show resolved Hide resolved
torchgeo/datamodules/cowc.py Outdated Show resolved Hide resolved
torchgeo/datamodules/eurosat.py Outdated Show resolved Hide resolved
torchgeo/datamodules/so2sat.py Outdated Show resolved Hide resolved
torchgeo/datamodules/ucmerced.py Outdated Show resolved Hide resolved
torchgeo/datamodules/usavars.py Outdated Show resolved Hide resolved
@adamjstewart
Copy link
Collaborator

Would it make sense to move all of the datamodule changes to a separate PR to merge first? It might be good to have that datamodule stuff in 0.3.0 and easier to review if it isn't all in one giant PR.

@calebrob6
Copy link
Member

Can I pull these commits out somehow?

@adamjstewart
Copy link
Collaborator

Yeah, just use git cherry-pick

@calebrob6
Copy link
Member

calebrob6 commented Jul 7, 2022

Cool, done, and I reset this branch back to ashnair's last commit.

"""Convert tuple of tensors to tuple of ints."""
return tuple(int(i.item()) for i in tensor_tuple)

original_shape = cast(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These shouldn't require a cast if everything is typed correctly.

@calebrob6
Copy link
Member

calebrob6 commented Sep 24, 2022

I may have fixed the mypy error. @adamjstewart can you give this another pass as it has been a minute?

@adamjstewart
Copy link
Collaborator

Last we left off, my only objection to this PR was PredictionDataset. If that can be removed and replaced with RasterDataset, then this will be good to go.

@ashnair1
Copy link
Collaborator Author

ashnair1 commented Sep 27, 2022

Recently came back to this. I incorporated Caleb's solution and wrote a new predict.py and it does indeed work for georeferenced images. This did require making preprocess a class method though.

For non-georeferenced images, this line will error out with the following error:

rasterio._err.CPLE_AppDefinedError: The transformation is already "north up" or a transformation between pixel/line and georeferenced coordinates cannot be computed for test_inria3/Raster2021_uint8_nogeo.tif. There is no affine transformation and no GCPs. Specify transformation option SRC_METHOD=NO_GEOTRANSFORM to bypass this check.

Another (perhaps minor) nit would be that it doesn't integrate well with Lightning's predict_dataloader. I was expecting the datamodule to handle the prepping of data based on the stage param and take advantage of respective dataloaders. Something like this

def setup(self, stage: Optional[str] = None) -> None:
    """Initialize the main ``Dataset`` objects.

    This method is called once per GPU per run.
    """
    if stage == "fit":
        # prepare train and val datasets
        self.train_dataset = ....
        self.val_dataset = ...
    if stage == "test"
        # prepare test dataset 
        self.test_dataset = .....
    if stage == "predict"
        # prepare the dataset representing inference directory
        self.predict_dataset = .....

@adamjstewart
Copy link
Collaborator

Where are we at with this PR? Some of the folks from IBM Research expressed interest in this feature.

@adamjstewart
Copy link
Collaborator

Just a heads up, we're trying to unify all pretrain/train/evaluate/predict scripts into one script in #1237

@adamjstewart
Copy link
Collaborator

Now that train.py supports both training and inference, I think we should close this PR. I think the best path forward is to modify predict_step in our trainers to store and stitch together predictions for each tile. Then we don't need PredictDataset or another script or anything else.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
datamodules PyTorch Lightning datamodules datasets Geospatial or benchmark datasets testing Continuous integration testing
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants