The main segmentation model is an EfficientUnet++ with an EfficientNet-b5 encoder.
The model was trained using a combination of three losses and a loss weighting scheme.
The weighting factor α (ramped up over the first 100 epochs) is defined as:
...
...
...
The following PyTorch Lightning settings and training tricks were used:
- `batch_size`: 32
- Adam optimizer with a `learning_rate` of 0.0003, scheduled with `CosineAnnealingLR` (`T_max=10`)
- mixed precision training
- gradient clipping enabled (mode: norm, value: 0.5)
- stochastic weight averaging
- Data samples of size 4x256x256 (CxWxH; channels: R,G,B,NIR), normalized
- Data augmentation:
  - `HorizontalFlip` or `VerticalFlip`, p=0.5
  - `RandomRotate90`, p=0.5
  - `RandomBrightnessContrast`, brightness_limit=0.2, contrast_limit=0.15, brightness_by_max=False
  - `Normalize`
- Normalization: 4-channel mean/std for all data from 2017-2020
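To make the `CosineAnnealingLR` setting concrete, here is a minimal sketch of the closed form PyTorch documents for that scheduler, evaluated for a base rate of 0.0003 and `T_max=10`. It assumes `eta_min=0` (PyTorch's default) and one scheduler step per epoch; the step interval is not stated above, so treat that as an assumption. Because the cosine is periodic, the rate reaches its minimum at step 10 and returns to the base rate at step 20 if training continues.

```python
import math


def cosine_annealing_lr(step: int, base_lr: float = 3e-4, t_max: int = 10,
                        eta_min: float = 0.0) -> float:
    """Closed-form cosine annealing:
    eta_t = eta_min + (base_lr - eta_min) * (1 + cos(pi * t / T_max)) / 2

    base_lr and t_max mirror the listed settings; eta_min=0.0 and
    one step per epoch are assumptions.
    """
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * step / t_max)) / 2


if __name__ == "__main__":
    # Print one full period (2 * T_max steps) of the schedule.
    for step in range(0, 21, 5):
        print(step, f"{cosine_annealing_lr(step):.6f}")
```

With a short `T_max` like 10, the schedule behaves as a repeating warm-restart-like cycle rather than a single decay over the whole run.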
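Gradient clipping in norm mode with a value of 0.5 rescales all gradients by one shared factor whenever their global L2 norm exceeds 0.5. In PyTorch Lightning this presumably corresponds to `gradient_clip_val=0.5` with `gradient_clip_algorithm="norm"`, which relies on `torch.nn.utils.clip_grad_norm_`. The sketch below shows only the rule itself on plain Python lists (each inner list stands in for one flattened parameter gradient; the helper name `clip_by_global_norm` is hypothetical) and uses the same `max_norm / (total_norm + 1e-6)` coefficient as PyTorch, capped at 1 so small gradients are never scaled up.

```python
import math
from typing import List


def clip_by_global_norm(grads: List[List[float]], max_norm: float = 0.5,
                        eps: float = 1e-6) -> List[List[float]]:
    """Scale every gradient by one shared factor so the global L2 norm is <= max_norm."""
    # Global norm across all parameters, not per tensor.
    total_norm = math.sqrt(sum(g * g for tensor in grads for g in tensor))
    # Never scale up: the coefficient is capped at 1.0.
    coef = min(1.0, max_norm / (total_norm + eps))
    return [[g * coef for g in tensor] for tensor in grads]


if __name__ == "__main__":
    clipped = clip_by_global_norm([[3.0], [4.0]])  # global norm 5.0 -> about 0.5
    print(clipped)
```

Clipping the global norm (rather than each tensor separately) preserves the direction of the overall update, which is why the same factor is applied to every gradient.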
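Stochastic weight averaging keeps an equal-weight running mean of parameter snapshots collected along the training trajectory. In PyTorch this is provided by `torch.optim.swa_utils.AveragedModel`, and Lightning wraps it in its `StochasticWeightAveraging` callback. The minimal sketch below shows only the averaging rule on a flat list of weights; the class name `RunningWeightAverage` is hypothetical. When averaging starts, how often snapshots are taken, and the SWA learning rate are not specified above. The batch-norm statistics recomputation that SWA usually needs afterwards is also left out.

```python
from typing import List, Optional


class RunningWeightAverage:
    """Equal-weight running mean of parameter snapshots (the SWA update rule)."""

    def __init__(self) -> None:
        self.n_averaged = 0
        self.avg: Optional[List[float]] = None

    def update(self, weights: List[float]) -> None:
        if self.avg is None:
            # The first snapshot is copied as-is.
            self.avg = list(weights)
        else:
            # avg <- avg + (w - avg) / (n + 1), i.e. the arithmetic mean so far.
            self.avg = [a + (w - a) / (self.n_averaged + 1)
                        for a, w in zip(self.avg, weights)]
        self.n_averaged += 1


if __name__ == "__main__":
    swa = RunningWeightAverage()
    for snapshot in ([0.0, 0.0], [2.0, 4.0], [4.0, 8.0]):
        swa.update(snapshot)
    print(swa.avg, swa.n_averaged)
```

The incremental form avoids storing every snapshot: only the current mean and a counter are kept.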
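The transform names in the augmentation list match classes in the Albumentations library. The sketch below is not that library's code. It is a hedged NumPy re-implementation of what those settings do on a 4-channel (C, H, W) float image, followed by per-channel standardization with statistics computed over a stack of images. Several details are assumptions: "HorizontalFlip or VerticalFlip, p=0.5" is read as "with probability 0.5, apply one of the two flips, chosen uniformly"; `RandomBrightnessContrast` is given Albumentations' default p=0.5 because no probability is listed; and `brightness_by_max=False` is read as scaling the brightness offset by the image mean instead of the dtype maximum. The helper names `channel_stats`, `augment`, and `normalize` are hypothetical, and random data stands in for the 2017-2020 tiles.

```python
from typing import Tuple

import numpy as np


def channel_stats(images: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Per-channel mean and std over a stack shaped (N, C, H, W)."""
    return images.mean(axis=(0, 2, 3)), images.std(axis=(0, 2, 3))


def augment(img: np.ndarray, rng: np.random.Generator) -> np.ndarray:
    """Apply the listed augmentations to one (C, H, W) float image."""
    out = img.astype(np.float32, copy=True)
    # Flip (assumed: one of horizontal/vertical, chosen uniformly), p=0.5.
    if rng.random() < 0.5:
        axis = 2 if rng.random() < 0.5 else 1  # axis 2 = width, axis 1 = height
        out = np.flip(out, axis=axis)
    # RandomRotate90, p=0.5: rotate the spatial plane by k * 90 degrees, k in 0..3.
    if rng.random() < 0.5:
        out = np.rot90(out, k=int(rng.integers(0, 4)), axes=(1, 2))
    # RandomBrightnessContrast (assumed p=0.5): contrast factor in [0.85, 1.15],
    # brightness offset in [-0.2, 0.2] scaled by the image mean.
    if rng.random() < 0.5:
        alpha = 1.0 + rng.uniform(-0.15, 0.15)
        beta = rng.uniform(-0.2, 0.2)
        out = out * alpha
        out = out + beta * out.mean()
    return np.ascontiguousarray(out, dtype=np.float32)


def normalize(img: np.ndarray, mean: np.ndarray, std: np.ndarray) -> np.ndarray:
    """Standardize each channel: (x - mean) / std, broadcast over H and W."""
    return (img - mean[:, None, None]) / std[:, None, None]


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    data = rng.random((8, 4, 256, 256), dtype=np.float32)  # stand-in tiles (R, G, B, NIR)
    mean, std = channel_stats(data)
    sample = normalize(augment(data[0], rng), mean, std)
    print(sample.shape)
```

Normalization is applied after the other transforms, mirroring the position of `Normalize` at the end of the augmentation list. The statistics are computed once over the whole dataset rather than per sample.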