Add prediction utilities #560
@ashnair1 great image, do you mind if we use this in a blog to showcase the cool things TorchGeo can do?
    SemanticSegmentationTask,
)

TASK_TO_MODULES_MAPPING: Dict[
Now that this is in two different scripts, we should move it somewhere where both scripts can find it to avoid code duplication. How about a torchgeo/common.py file? I don't want to put it in torchgeo/__init__.py because this will be sourced on every import.
I like the idea of torchgeo/common.py. I can add that in a follow-up PR.
+1 to this
predict.py
def write_mask(mask: torch.Tensor, output_dir: str, input_filename: str) -> None:
    """Write mask to specified output directory."""
Docstring needs descriptions for all inputs, same for other functions
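As a rough illustration of the requested docstring style, the sketch below documents every input of a small path-building helper. The helper name `build_mask_path`, the `_mask` suffix, and the `Args:`/`Returns:` layout are assumptions made for illustration; none of them are taken from this PR.

```python
import os


def build_mask_path(output_dir: str, input_filename: str, suffix: str = "_mask") -> str:
    """Build the path a predicted mask would be written to.

    Hypothetical helper used only to illustrate the docstring layout.

    Args:
        output_dir: directory in which predicted masks are saved
        input_filename: filename (or path) of the image the mask was predicted from
        suffix: text appended to the input's stem (assumed naming scheme)

    Returns:
        path of the output mask inside ``output_dir``
    """
    # Keep the original extension so a GeoTIFF input yields a GeoTIFF-named output.
    stem, ext = os.path.splitext(os.path.basename(input_filename))
    return os.path.join(output_dir, f"{stem}{suffix}{ext}")
```

Each parameter gets its own line under `Args:`, so a reviewer can check the signature against the description at a glance.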
conf = OmegaConf.load(os.path.join(config_dir, "experiment_config.yaml"))
ckpt = os.path.join(config_dir, "last.ckpt")
Should these filenames be parameters?
The ckpt filename, sure. I think the config filename ("experiment_config.yaml") is hard-coded.
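One way the checkpoint filename could become a parameter while the config filename stays fixed is sketched below. The flag names `--config-dir` and `--ckpt-name` and both function names are hypothetical; only the two filenames come from the snippet under review.

```python
import argparse
import os
from typing import Tuple

CONFIG_FILENAME = "experiment_config.yaml"  # kept fixed, as in the reviewed snippet


def build_parser() -> argparse.ArgumentParser:
    """Create a parser exposing the checkpoint filename (hypothetical flags)."""
    parser = argparse.ArgumentParser(description="Prediction script sketch")
    parser.add_argument("--config-dir", type=str, required=True,
                        help="Directory containing the config and checkpoint")
    parser.add_argument("--ckpt-name", type=str, default="last.ckpt",
                        help="Checkpoint filename inside --config-dir")
    return parser


def resolve_paths(args: argparse.Namespace) -> Tuple[str, str]:
    """Return (config path, checkpoint path) for the parsed arguments."""
    config_path = os.path.join(args.config_dir, CONFIG_FILENAME)
    ckpt_path = os.path.join(args.config_dir, args.ckpt_name)
    return config_path, ckpt_path
```

Defaulting `--ckpt-name` to `last.ckpt` keeps the current behaviour when the flag is omitted.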
    help="Path to output_directory to save predicted mask geotiffs",
)

parser.add_argument("--device", type=str, default="cuda", choices=["cuda", "cpu"])
Is this something we should let PyTorch Lightning handle?
Perhaps. LightningModules shouldn't need this but I'll need to verify. Will get back to this.
While LightningModules are aware of which device they're on, the models (UNet, DeepLab etc) are not. Since we can't forward device info, this will be required.
if __name__ == "__main__":
    # Taken from https://github.com/pangeo-data/cog-best-practices
    _rasterio_best_practices = {
This can also be moved to torchgeo/common.py
)


class PredictDataset(Dataset[Any]):
Can you clarify why this is necessary?
PredictDataset represents a dataset of images in a folder that you want to run inference on. When you run predict.py you need to pass a predict_on flag. This can be either a test split of a dataset or a folder of images that you want to run inference on. PredictDataset handles the latter.
It accepts a bands argument, since we can't predict the number of bands the GeoTIFFs will have, and it always applies one transform (patch_sample), since we can't predict the size of the GeoTIFF.
Hmm, I'm still not understanding. Why not use the normal data loader for that dataset (a subclass of either VisionDataset or GeoDataset)? RasterDataset in combination with GridGeoSampler (or just a normal VisionDataset) is specifically designed to handle "a folder of images you want to run inference on". It seems like this won't work for all datasets since it only handles flat directories of images that can be opened with rasterio and don't need to be chipped.
Why not use the normal data loader for that dataset (a subclass of either VisionDataset or GeoDataset)?
Because the dataloader of that dataset has custom logic for parsing filenames and subdirectories that is not relevant when you want to run inference on a directory of images. You could add code to the _load_files method of each dataset that skips the parsing logic if it is in "predict" mode and I could see a case being made for that.
I just found having a separate class (PredictDataset) that represents a folder of images to be easier. Just need to instantiate it in the corresponding datamodule as seen in this PR.
It seems like this won't work for all datasets since it only handles flat directories of images that can be opened with rasterio and don't need to be chipped.
That's the exact use case being targeted. You have a bunch of images of arbitrary size in a folder and you want to produce inference results.
These images will of course need to be chipped to the size the model expects, which is why PredictDataset contains a single transform that it will always apply: patch_sample, which extracts tensor patches according to the patch_size the model was trained on.
Because the dataloader of that dataset has custom logic for parsing filenames and subdirectories that is not relevant when you want to run inference on a directory of images.
But this doesn't cause any issues, right? Yes, you could use a simpler class for parsing a directory of images, but we already have datasets for everything. If there's no reason why those existing datasets don't work then I don't see why we need a new dataset.
It sounds like you want to train a model on Inria and then make predictions on a completely different dataset with similar images
Precisely. For example, one of the locations within the Inria dataset is Austin. So if someone gives me some aerial imagery with similar resolution (30cm) over Dallas, Houston or even Austin, I want to be able to use this model to run inference.
That's fine, but maybe a name other than PredictDataset would be better.
Maybe PredictionFolder? I'm open to suggestions.
I was under the mistaken impression that PredictDataset would be required for all datasets you want to use in predict.py.
Ah I see. PredictDataset was meant to be a standalone class to be used when you want to run inference on a folder of images of arbitrary size and variable number of bands. You would inject it into the datamodule so that it can be used with predict.py as seen in this PR. It's not connected to any dataset.
We already have RasterDataset and VectorDataset subclasses of GeoDataset that recursively locate files that can be opened with rasterio and fiona. Maybe we need a new ImageDataset subclass of VisionDataset that recursively locates files that can be opened with PIL?
Do we need PIL? I think we can stick with rasterio but not use RasterDataset. Geospatial info is not necessary for the prediction process but we don't want the process to fail if the image lacks geospatial info. Apart from throwing a NotGeoreferencedWarning, rasterio can read and write any image with/without geospatial info.
Ah, didn't even notice that you were using rasterio. In that case, this class definitely isn't needed. You should definitely use RasterDataset for this specific dataset. You might not need geospatial info if you only want to predict on one file, but as soon as you want to predict on multiple files, you need geospatial info to stitch together the predictions.
If the file in question does not have geospatial info then RasterDataset can't be used. Also the patching and merging/stitching are done by kornia's ExtractTensorPatches and CombineTensorPatches.
Just so we're on the same page, this is how I visualise the workflow:
The main inputs to the predict.py script are a directory of images (inference folder) and the trained model. The directory contains images with or without georeferencing and of different sizes.
The script runs over each image, patches/chips it to the appropriate size for the model, runs model inference on the patches and merges the results back together.
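The directory-in, directory-out workflow described above can be sketched roughly as follows. The function name `predict_directory`, the `*.tif` pattern, and the `predict_image` callback (standing in for the chip, infer, and merge steps) are assumptions for illustration rather than the PR's actual code.

```python
from pathlib import Path
from typing import Callable, List


def predict_directory(
    input_dir: str,
    output_dir: str,
    predict_image: Callable[[Path, Path], None],
    pattern: str = "*.tif",
) -> List[Path]:
    """Call ``predict_image(src, dst)`` for every matching image in a flat folder.

    ``predict_image`` is a placeholder for "chip, run inference, merge, write".
    """
    out_root = Path(output_dir)
    out_root.mkdir(parents=True, exist_ok=True)
    outputs: List[Path] = []
    # Sort for a deterministic processing order.
    for image_path in sorted(Path(input_dir).glob(pattern)):
        out_path = out_root / image_path.name  # one prediction per input image
        predict_image(image_path, out_path)
        outputs.append(out_path)
    return outputs
```

Keeping the per-image step behind a callback leaves open whether the patches come from kornia or from a sampler over a RasterDataset.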
This works fine if you want to convert a directory of images to a directory of predictions, but what if you want to predict a single static map of an entire country by stitching together those predictions? I think that's where predict.py would really shine.
I still think the best approach is to use RasterDataset for geospatial data and some kind of subclass of VisionDataset for non-geospatial data. Also, these datasets could be used at training time for some kind of self-supervised learning, so PredictDataset wouldn't be the best name.
It seems we're coming at this from two different angles.
From my POV, predict.py was only intended to convert a directory of images to a directory of predictions. The motivation was simple: I have a model and I have images from certain areas of interest that I want to run inference on. I might want to merge them together but I could just use gdal_merge.py for that.
The images could be spatially distributed, i.e. AOI1.tif, AOI2.tif etc., or temporally distributed, i.e. AOI_2020.tif, AOI_2019.tif, AOI_2018.tif.
Another place where this would be useful is when you want to upload predictions to an evaluation server. In Inria, the images are of size 5000 x 5000. For training I used 512 x 512 crops. But if I want to upload the predictions to an evaluation server they would expect predictions to be the same size as input i.e. 5000 x 5000.
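The size mismatch in that example can be worked through with a little arithmetic. The sketch below assumes non-overlapping patches and padding up to the next multiple of the patch size; how the PR actually distributes the padding is not stated here, and the helper name `pad_to_multiple` is hypothetical.

```python
import math
from typing import Tuple


def pad_to_multiple(size: int, patch_size: int) -> Tuple[int, int]:
    """Return (total padding, number of patches) along one image dimension.

    Assumes non-overlapping patches, i.e. stride equal to the patch size.
    """
    num_patches = math.ceil(size / patch_size)
    padding = num_patches * patch_size - size
    return padding, num_patches


# Inria tiles are 5000 px per side and the model was trained on 512 px crops.
padding, per_side = pad_to_multiple(5000, 512)
print(padding, per_side, per_side**2)  # 120 10 100
```

Under these assumptions each tile is padded by 120 px per dimension, split into 100 patches, and the merged prediction is cropped back to 5000 x 5000 before upload.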
Predicting a single static map or using it for self supervised learning is well outside my original scope.
In order to properly address this, there needs to be some consensus on which prediction scenarios torchgeo is interested in supporting:
- Will it support images with and without georeference?
- Will it support running on a directory of images and producing a directory of predictions?
- Will it provide an option to merge the directory of predictions if they have geo-reference into a single static map?
)
patch_shape = cast(Tuple[int, int], tensor_to_int(batch["patch_shape"]))
padding = cast(Tuple[int, int], tensor_to_int(batch["padding"]))
patch_combine = CombineTensorPatches(
Can you tell me more about how this function works? What happens if your patches have overlap? Is this sufficient to close #30 or do we need something more powerful/generic?
Can you tell me more about how this function works?
I've written a tutorial describing how extract and combine patches work here
What happens if your patches have overlap?
Currently you can extract patches with overlap (via the stride parameter) but you can't merge them together. This is because CombineTensorPatches currently only supports stride=window_size as seen here
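To make the stride=window_size case concrete, here is a plain-Python illustration of splitting a 2-D grid into non-overlapping patches and stitching them back. It is not kornia's implementation; the function names and the nested-list representation are assumptions chosen to keep the example dependency-free.

```python
from typing import List

Grid = List[List[int]]


def extract_patches(image: Grid, size: int) -> List[Grid]:
    """Split ``image`` into non-overlapping ``size`` x ``size`` patches (row-major)."""
    height, width = len(image), len(image[0])
    if height % size or width % size:
        raise ValueError("image dimensions must be multiples of the patch size")
    return [
        [row[left:left + size] for row in image[top:top + size]]
        for top in range(0, height, size)
        for left in range(0, width, size)
    ]


def combine_patches(patches: List[Grid], patches_per_row: int) -> Grid:
    """Inverse of ``extract_patches``; only valid when stride == window size."""
    size = len(patches[0])
    image: Grid = []
    for start in range(0, len(patches), patches_per_row):
        band = patches[start:start + patches_per_row]  # one horizontal strip
        for r in range(size):
            image.append([value for patch in band for value in patch[r]])
    return image


image = [[r * 4 + c for c in range(4)] for r in range(4)]
patches = extract_patches(image, 2)
assert combine_patches(patches, patches_per_row=2) == image
```

Because every pixel belongs to exactly one patch, recombining is pure concatenation; with overlap, some rule for resolving duplicated pixels would be needed.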
Is this sufficient to close #30 or do we need something more powerful/generic?
Based on my understanding of the scope of #30, I would say no. If the goal is to just enable users to extract patches and stitch them together, kornia's ExtractTensorPatches and CombineTensorPatches are sufficient. Once CombineTensorPatches supports stride!=window_size, we will be able to handle patches with overlap.
But for the alternate stitching techniques (like label averaging) mentioned in the paper referenced in #30, we might need something more powerful as CombineTensorPatches doesn't support this.
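For intuition about stitching with overlap, the sketch below averages overlapping 1-D patches. This is one simple reading of averaging overlapping predictions, not necessarily the method from the paper referenced in #30, and `average_overlapping` is a hypothetical name.

```python
from typing import List, Sequence


def average_overlapping(
    patches: Sequence[Sequence[float]], stride: int, length: int
) -> List[float]:
    """Stitch 1-D patches placed every ``stride`` samples, averaging overlaps."""
    total = [0.0] * length
    count = [0] * length
    for index, patch in enumerate(patches):
        start = index * stride
        for offset, value in enumerate(patch):
            total[start + offset] += value
            count[start + offset] += 1
    if 0 in count:
        raise ValueError("patches do not cover the full output")
    # Each output sample is the mean of every prediction that touched it.
    return [t / c for t, c in zip(total, count)]


# Two patches of width 4 with stride 2 overlap on positions 2 and 3.
stitched = average_overlapping([[1, 1, 1, 1], [3, 3, 3, 3]], stride=2, length=6)
print(stitched)  # [1.0, 1.0, 2.0, 2.0, 3.0, 3.0]
```

The same accumulate-and-count idea extends to 2-D tensors of class probabilities.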
Not at all 👍 But I just want to mention that the image on the left is an image from the inria dataset (I think
Which model did you use to produce those masks in the image?
U-Net with resnet18 backbone. Think I used the exact same config as the one in
647ef0e
to
4e7e47f
Compare
I'm going to pull the thread that is going on above down here so it's easier to follow:
Part 1
I think most of the discussion is nicely summarized by the three questions @ashnair1 is bringing up:
I think (point-by-point):
I think we can do this without adding a
Part 2
Separate to the above questions, there is the question of whether we should allow treating the test DataLoader of one of our existing
To sketch out what I'm doing:
Here is a working implementation https://gist.github.com/calebrob6/17147f6a256486f08c0d912f40114fbf where I've taken a landcover.ai trained model and run it on an Inria test scene. Bonus: the predictions make some sense! I've annotated in TODOs where it needs some more fleshing out.
Some example screenshots from QGIS (where I'm visualizing the imagery and outputs):
This is what the predictions look like with
This is what the predictions look like with
(naively this will leave a strip of 96 empty pixels along the edges of the larger tile -- another TODO would be not throwing away those predictions)
@ashnair1 @adamjstewart, let me know what you think!
I think it's important to distinguish needs vs. wants vs. dreams. Needs are things that should probably be supported before this PR is merged. Wants and dreams are for a future adventurous coder to implement, but I'm happy to discuss where I envision this script going in the future.
Needs: I think a basic requirement of
Wants: I think the most common use case for this script would be for a user to point it to a directory of image files and make predictions on those files. But I think it would also be cool to support builtin Datasets and DataModules. Maybe this overlaps with
Dreams: I would love to download imagery for the entire Earth, stuff it in a directory, run
Another important thing to think about is how all of these scripts play with each other. The way I see it, which may or may not be the way things currently are:
Are we all on the same page with this, or is there disagreement? I also want the configs in All of the I still don't think |
Yep, I agree with the needs/wants/dreams (although the "dream" seems way out of scope for us to implement in torchgeo). The code in the gist above satisfies the functionality of the "needs" and can run on a directory of images (part of the "wants").
Agree, but this doesn't seem very relevant to this PR.
My implementation doesn't use a
Yup. I’ve always seen the scripts that way. Thanks @calebrob6. Your solution looks like it might solve my current issue but I need to test it out first. The only issue I see is how to handle non-georeferenced images but I'm guessing we can switch to kornia's functions (
Agreed. No need to tie this to 0.3.0. I’ll take a look at @calebrob6’s solution and see if it works for my use cases. I’ll be travelling this week so I’ll get back to this when I’m back.
Good point about the non-georeferenced images (particularly, images that rasterio can't read). If you don't mind, I can push a cleaned up version of what I have (i.e. with some of the TODOs fixed 😄) on this PR.
Just did some testing and it looks like
But if it can't read geospatial metadata I'm not sure how it will add them to the rtree index.
It doesn't really need an rtree index as there will only be one file at a time. We really just want to re-use the logic from
The previous I also took this opportunity to clean up a few things throughout the DataModules -- I'll go through and explain these changes inline with comments.
Would it make sense to move all of the datamodule changes to a separate PR to merge first? It might be good to have that datamodule stuff in 0.3.0 and easier to review if it isn't all in one giant PR.
Can I pull these commits out somehow?
Yeah, just use git cherry-pick
Cool, done, and I reset this branch back to ashnair's last commit.
    """Convert tuple of tensors to tuple of ints."""
    return tuple(int(i.item()) for i in tensor_tuple)


original_shape = cast(
These shouldn't require a cast if everything is typed correctly.
I may have fixed the mypy error. @adamjstewart can you give this another pass as it has been a minute?
Last we left off, my only objection to this PR was
Recently came back to this. I incorporated Caleb's solution and wrote a new For non-georeferenced images, this line will error out with the following error:
Another (perhaps minor) nit would be that it doesn't integrate well with Lightning's

    def setup(self, stage: Optional[str] = None) -> None:
        """Initialize the main ``Dataset`` objects.

        This method is called once per GPU per run.
        """
        if stage == "fit":
            # prepare train and val datasets
            self.train_dataset = ...
            self.val_dataset = ...
        if stage == "test":
            # prepare test dataset
            self.test_dataset = ...
        if stage == "predict":
            # prepare the dataset representing inference directory
            self.predict_dataset = ...
Where are we at with this PR? Some of the folks from IBM Research expressed interest in this feature.
Just a heads up, we're trying to unify all pretrain/train/evaluate/predict scripts into one script in #1237
Now that
The main aim of this PR is to add utilities that will allow users to use a torchgeo segmentation model to create predictions on images (GeoTIFFs) of arbitrary size.
Example using a model trained on the InriaAerialImageLabelling dataset shown below. Images in this dataset are of size 5000 x 5000 and the model was trained on 512 x 512 crops.