
Add ChangeDetectionTask #2422

Draft · wants to merge 5 commits into main
Conversation


@keves1 keves1 commented Nov 21, 2024

This PR adds a change detection trainer, as discussed in #2382.

Key points/items to discuss:

  • I used the OSCD dataset for testing and modified it to use a temporal dimension.
  • With the added temporal dimension, Kornia's AugmentationSequential doesn't work on its own, but it can be combined with VideoSequential to support the temporal dimension (see the Kornia docs). I overrode self.aug in OSCDDataModule to do this, but I'm not sure whether this should be incorporated into BaseDataModule instead.
  • VideoSequential adds a temporal dimension to the mask. I'm not sure if there is a way to avoid this, or if it is even desirable, but I added an if statement to the AugmentationSequential wrapper to check for and remove the added dimension.
  • The OSCDDataModule applies the _RandomNCrop augmentation, but this does not work for time series data. I'm not sure how to modify _RandomNCrop to fix this and would appreciate some guidance.
  • There are a few tests that I still need to make pass.
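The mask-dimension check described above could be sketched roughly like this (a minimal plain-PyTorch illustration, not the PR's actual wrapper code; the exact position of the temporal axis after augmentation is an assumption):

```python
import torch

# Hypothetical sketch: after a video-style augmentation, a mask that went
# in as (B, C, H, W) may come back with an extra temporal axis, e.g.
# (B, C, T, H, W). When T == 1, squeeze it back out so downstream code
# sees the original 4D shape.
def squeeze_temporal_mask(mask: torch.Tensor) -> torch.Tensor:
    if mask.ndim == 5 and mask.size(2) == 1:
        mask = mask.squeeze(2)
    return mask

augmented = torch.zeros(4, 1, 1, 32, 32)  # (B, C, T=1, H, W)
out = squeeze_temporal_mask(augmented)
assert out.shape == (4, 1, 32, 32)
```

A 4D mask passes through unchanged, so the helper is safe to call unconditionally.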

cc @robmarkcole

@github-actions bot added labels: datasets (Geospatial or benchmark datasets), testing (Continuous integration testing), trainers (PyTorch Lightning trainers), transforms (Data augmentation transforms), datamodules (PyTorch Lightning datamodules) on Nov 21, 2024
from .base import BaseTask


class FocalJaccardLoss(nn.Module):
Contributor

Should go in torchgeo/losses/ ?

Author

That makes sense, I'll put this and XEntJaccardLoss there.
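For context, a combined focal + Jaccard loss along these lines could look like the following (a hedged sketch in plain PyTorch, not the PR's actual implementation; the alpha/gamma weighting and parameter names are assumptions for illustration):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Hypothetical sketch of a combined focal + soft-Jaccard (IoU) loss,
# roughly the kind of thing that could live in torchgeo/losses/.
class FocalJaccardLoss(nn.Module):
    def __init__(self, alpha: float = 0.5, gamma: float = 2.0) -> None:
        super().__init__()
        self.alpha = alpha  # weight between the two terms (assumed)
        self.gamma = gamma  # focal focusing parameter

    def forward(self, logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        # Focal term: cross-entropy down-weighted for easy examples.
        ce = F.cross_entropy(logits, target, reduction='none')
        pt = torch.exp(-ce)  # probability of the true class
        focal = ((1 - pt) ** self.gamma * ce).mean()
        # Soft Jaccard term on predicted probabilities vs. one-hot target.
        probs = logits.softmax(dim=1)
        one_hot = F.one_hot(target, logits.size(1)).permute(0, 3, 1, 2).float()
        inter = (probs * one_hot).sum(dim=(0, 2, 3))
        union = (probs + one_hot - probs * one_hot).sum(dim=(0, 2, 3))
        jaccard = 1 - (inter / union.clamp(min=1e-7)).mean()
        return self.alpha * focal + (1 - self.alpha) * jaccard
```

An XEntJaccardLoss would follow the same shape with the plain cross-entropy term in place of the focal term.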

loss: 'ce'
model: 'unet'
backbone: 'resnet18'
in_channels: 13
Contributor

If stacking the channels should this be 2 * 13?

Collaborator

We are no longer stacking channels, we are making all time series datasets (including change detection) into B x T x C x H x W

Contributor

OK, and it is not necessary to configure T?

Author

in_channels is multiplied by 2 in configure_models when initializing Unet.
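The in_channels * 2 trick amounts to collapsing the bitemporal pair into the channel axis before the model sees it. A minimal sketch of that reshape (illustrative shapes only, not the configure_models code):

```python
import torch

# Hypothetical sketch: a bitemporal batch of shape (B, T=2, C, H, W)
# is flattened to (B, 2*C, H, W) so a single-input segmentation model
# (e.g. a U-Net built with in_channels * 2) can consume both timesteps.
b, t, c, h, w = 4, 2, 13, 32, 32
x = torch.rand(b, t, c, h, w)
x_stacked = x.flatten(1, 2)  # merge temporal and channel dims
assert x_stacked.shape == (4, 26, 32, 32)
```

Because flatten is time-major here, the first 13 channels of the stacked tensor are the t=0 image and the last 13 are the t=1 image.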

labels: Optional labels to use for classes in metrics
e.g. ["background", "change"]
loss: Name of the loss function, currently supports
'ce', 'jaccard', 'focal' or 'focal-jaccard' loss.
Contributor

and ce-jaccard?

Author

Yes, I missed that, thanks.



class XEntJaccardLoss(nn.Module):
"""XEntJaccardLoss."""
Contributor

Is there a paper ref for this?

Author

Not that I'm aware of, and I didn't find one from a quick search. This loss function was already present in the code I started from.

"""
loss: str = self.hparams["loss"]
ignore_index = self.hparams["ignore_index"]
if loss == "ce":
Contributor

Noting that for change detection BCE may be more appropriate

#245
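For the binary case suggested here, a BCE setup might look like the following (a hedged sketch, not this PR's code; the pos_weight value is purely illustrative and would need tuning per dataset):

```python
import torch
import torch.nn as nn

# Sketch, assuming binary change maps: BCEWithLogitsLoss with pos_weight
# to counteract the heavy class imbalance typical of change detection,
# where "change" pixels are rare relative to "no change".
pos_weight = torch.tensor([10.0])  # illustrative value
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

logits = torch.randn(4, 1, 32, 32)                    # raw model outputs
target = torch.randint(0, 2, (4, 1, 32, 32)).float()  # binary change mask
loss = criterion(logits, target)
```

Note that BCEWithLogitsLoss expects raw logits and a float target of the same shape, unlike cross_entropy's integer class-index target.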

@robmarkcole
Contributor

I wonder if we should limit the scope to change between two timesteps and binary change - then we can use binary metrics and provide a template for the plot methods. I say this because this is the most common change detection task by a mile. Might also simplify the augmentations approach? Treating as a video sequence seems overkill.
Also I understand there is support for multitemporal coming later.

@adamjstewart
Collaborator

I wonder if we should limit the scope to change between two timesteps

I'm personally okay with this, although @hfangcat has a recent work using multiple pre-event images that would be nice to support someday (could be a subclass if necessary).

and binary change

Again, this would probably be fine as a starting point, although I would someday like to make all trainers support binary/multiclass/multilabel, e.g., #2219.

provide a template for the plot methods.

Could also do this in the datasets (at least for benchmark NonGeoDatasets). We're also trying to remove explicit plotting in the trainers: #2184

I say this because this is the most common change detection task by a mile.

Agreed.

Might also simplify the augmentations approach? Treating as a video sequence seems overkill.

I actually like the video augmentations, but let me loop in the Kornia folks to get their opinion: @edgarriba @johnnv1

Also I understand there is support for multitemporal coming later.

Correct, see #2382 for the big picture (I think I also sent you a recording of my presented plan).

@adamjstewart
Collaborator

VideoSequential adds a temporal dimension to the mask. Not sure if there is a way to avoid this

Can you try keepdim=True?

I added an if statement to the AugmentationSequential wrapper to check for and remove this added dimension.

@ashnair1 would this work directly with K.AugmentationSequential now? We are trying to phase out our AugmentationSequential wrapper now that upstream supports (almost?) everything we need.

@keves1
Author

keves1 commented Nov 22, 2024

I will go ahead and change this to binary change with two timesteps; that sounds like a good starting point.

Can you try keepdim=True?

I tried this and it didn't get rid of the other dimension. I also looked into extra_args but didn't see any options to help with this.

Could also do this in the datasets (at least for benchmark NonGeoDatasets). We're also trying to remove explicit plotting in the trainers

I was going to add plotting in the trainer, but would you rather not then? What would this look like in the dataset?

@robmarkcole
Contributor

Perhaps there should even be a base class ChangeDetection and subclasses for BinaryChangeDetection etc?

@adamjstewart
Collaborator

That's exactly what I'm trying to undo in #2219.

@adamjstewart
Collaborator

I was going to add plotting in the trainer, but would you rather not then?

We can copy-and-paste the validation_step plotting code used by other trainers, but that's probably going to disappear soon (I think we're just waiting on testing in #2184).

What would this look like in the dataset?

See OSCD.plot()
