Reimplementations of spatial transformer networks (STN) and the inverse compositional STN (ICSTN).

Although implementations already exist, this one focuses on simplicity and ease of understanding of the vision transforms and the model.
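To make the core idea concrete, here is a minimal affine spatial transformer sketch in modern PyTorch. It is an illustration under stated assumptions, not the repo's actual `BasicSTNModule`: the localization net architecture and input size (28x28, single channel) are assumptions, and the final layer is initialized to the identity transform, which is the standard trick for stable STN training.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class BasicAffineSTN(nn.Module):
    """Minimal affine STN sketch (hypothetical, not the repo's module):
    a small localization net predicts the 6 affine parameters."""
    def __init__(self):
        super().__init__()
        self.loc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(28 * 28, 32), nn.ReLU(),
            nn.Linear(32, 6))
        # Initialize to the identity transform so the warp starts as a no-op.
        nn.init.zeros_(self.loc[-1].weight)
        self.loc[-1].bias.data = torch.tensor([1., 0., 0., 0., 1., 0.])

    def forward(self, x):
        theta = self.loc(x).view(-1, 2, 3)          # per-image 2x3 affine matrix
        grid = F.affine_grid(theta, x.size(), align_corners=False)
        return F.grid_sample(x, grid, align_corners=False)
```

With identity initialization the module initially passes images through unchanged; the classifier's gradient then drives the predicted warp away from the identity.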
During training, a random homography perturbation is applied to each image in the minibatch. Each perturbation is composed of component transformations (rotation, translation, shear, projection), with the parameters of each component sampled from uniform(-1, 1) scaled by a multiplicative factor of 0.25.
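The sampling scheme above can be sketched in NumPy as follows. This is a hedged illustration, not the repo's code: the function name and the exact composition order of the component matrices are assumptions, and only one parameter each is drawn for rotation and shear for brevity.

```python
import numpy as np

def sample_homography(scale=0.25, rng=None):
    """Compose a random 3x3 homography from rotation, translation, shear,
    and projective components, each parameter drawn from
    uniform(-1, 1) * scale. (Illustrative sketch, not the repo's code.)"""
    if rng is None:
        rng = np.random.default_rng()
    r, tx, ty, sh, p1, p2 = rng.uniform(-1, 1, 6) * scale
    R = np.array([[np.cos(r), -np.sin(r), 0],   # rotation
                  [np.sin(r),  np.cos(r), 0],
                  [0,          0,         1]])
    T = np.array([[1, 0, tx], [0, 1, ty], [0, 0, 1]])   # translation
    S = np.array([[1, sh, 0], [0, 1, 0], [0, 0, 1]])    # shear
    P = np.array([[1, 0, 0], [0, 1, 0], [p1, p2, 1]])   # projection
    return R @ T @ S @ P
```

With `scale=0` every component is the identity, so the composed homography is the 3x3 identity matrix.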
Example homography perturbation:
Model | Accuracy | Training params |
---|---|---|
Basic affine STN | 91.59% | 10 epochs at learning rate 1e-3 (classifier and transformer) |
Homography STN | 93.30% | 10 epochs at learning rate 1e-3 (classifier and transformer) |
Homography ICSTN | 97.67% | 10 epochs at learning rate 1e-3 (classifier) and 5e-4 (transformer) |
*(Sample image grids for each model: original, perturbed, and transformed; images not rendered.)*
*(Comparison images: original, perturbed, and transformed samples for the Basic affine STN, Homography STN, and Homography ICSTN; images not rendered.)*
To train a model:

```
python train.py --output_dir=[path to params.json] \
    --restore_file=[path to .pt checkpoint if resuming training] \
    --cuda=[cuda device id]
```
`params.json` provides training parameters and specifies which spatial transformer module to use:

- `BasicSTNModule` -- affine transform localization network
- `STNModule` -- homography transform localization network
- `ICSTNModule` -- homography transform localization network (cf. Lin, ICSTN paper)
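A `params.json` might look like the fragment below. The field names here are illustrative assumptions (the actual schema is defined by the training code); the values are taken from the results table above (10 epochs, classifier learning rate 1e-3, transformer learning rate 5e-4 for the ICSTN).

```json
{
  "model": "ICSTNModule",
  "num_epochs": 10,
  "learning_rate": 0.001,
  "stn_learning_rate": 0.0005
}
```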
To evaluate and visualize results:

```
python evaluate.py --output_dir=[path to params.json] \
    --restore_file=[path to .pt checkpoint] \
    --cuda=[cuda device id]
```
- python 3.6
- pytorch 0.4
- torchvision
- tensorboardX
- numpy
- matplotlib
- tqdm