PyTorch reimplementation of Spatial Transformer Networks and Inverse Compositional STN

kamenbliznashki/spatial_transformer

Spatial Transformer Networks

Reimplementations of:

  • Spatial Transformer Networks (STN)
  • Inverse Compositional Spatial Transformer Networks (ICSTN)

Although other implementations already exist, this one focuses on simplicity and on making the vision transforms and the model easy to understand.

Results

During training, a random homography perturbation is applied to each image in the minibatch. Each perturbation is a composition of component transformations (rotation, translation, shear, projection), with the parameters of each component sampled from Uniform(-1, 1) and scaled by a multiplicative factor of 0.25.
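A minimal NumPy sketch of how such a perturbation could be built. The matrix form of each component, the composition order, the function name `sample_perturbation`, and the use of normalized [-1, 1] image coordinates are assumptions for illustration, not taken from the repository's code.

```python
import numpy as np

def sample_perturbation(scale=0.25, rng=None):
    """Sample a random 3x3 homography as a composition of component transforms.

    Assumption: every component parameter is Uniform(-1, 1) * scale and acts on
    normalized [-1, 1] coordinates; the repository's parameterization may differ.
    """
    rng = np.random.default_rng() if rng is None else rng

    def u(n):
        # n independent draws from Uniform(-1, 1), scaled by the multiplicative factor
        return rng.uniform(-1.0, 1.0, size=n) * scale

    theta = u(1)[0]
    rotation = np.array([[np.cos(theta), -np.sin(theta), 0.0],
                         [np.sin(theta),  np.cos(theta), 0.0],
                         [0.0, 0.0, 1.0]])
    tx, ty = u(2)
    translation = np.array([[1.0, 0.0, tx],
                            [0.0, 1.0, ty],
                            [0.0, 0.0, 1.0]])
    sx, sy = u(2)
    shear = np.array([[1.0, sx, 0.0],
                      [sy, 1.0, 0.0],
                      [0.0, 0.0, 1.0]])
    px, py = u(2)
    projection = np.array([[1.0, 0.0, 0.0],
                           [0.0, 1.0, 0.0],
                           [px, py, 1.0]])
    # Matrices compose right to left: projection acts first, rotation last.
    H = rotation @ translation @ shear @ projection
    return H / H[2, 2]  # fix the homogeneous scale so H[2, 2] == 1
```

With `scale=0.0` every component reduces to the identity, which is a quick sanity check on the composition. A point (x, y) is perturbed by multiplying H with (x, y, 1) and dividing by the third component.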

Example homography perturbation (figure).

Test set accuracy:

Model               Accuracy   Training params
Basic affine STN    91.59%     10 epochs, learning rate 1e-3 (classifier and transformer)
Homography STN      93.30%     10 epochs, learning rate 1e-3 (classifier and transformer)
Homography ICSTN    97.67%     10 epochs, learning rate 1e-3 (classifier) and 5e-4 (transformer)

Sample alignment results (for each model, a figure with rows of original, perturbed, and transformed image samples):

  • Basic affine STN
  • Homography STN
  • Homography ICSTN

Mean and variance of the aligned results (cf. Lin, ICSTN paper):

Mean image (figure: rows original, perturbed, transformed; columns Basic affine STN, Homography STN, Homography ICSTN)

Variance (figure: same rows and columns as the mean image)
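Both statistics are computed pixel-wise over a set of aligned images: tighter alignment gives a sharper mean image and lower variance. A minimal sketch, assuming the images are stacked into an array of shape (N, H, W); how the images are grouped (for example, per digit class) and the helper name `alignment_statistics` are assumptions here.

```python
import numpy as np

def alignment_statistics(images):
    """Pixel-wise mean and (population) variance over a stack of shape (N, H, W)."""
    stack = np.asarray(images, dtype=np.float64)
    return stack.mean(axis=0), stack.var(axis=0)
```

A perfectly aligned set of identical images would give a variance image of all zeros.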

Usage

To train a model:

python train.py --output_dir=[path to params.json] \
                --restore_file=[path to .pt checkpoint if resuming training] \
                --cuda=[cuda device id]

params.json provides training parameters and specifies which spatial transformer module to use:

  1. BasicSTNModule -- affine transform localization network
  2. STNModule -- homography transform localization network
  3. ICSTNModule -- homography transform localization network (cf. Lin, ICSTN paper)
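The difference between a one-shot homography STN and the inverse compositional variant can be sketched without a network. In this hypothetical NumPy version (the repository's modules are PyTorch and may differ), homographies use an assumed 8-parameter form that is the identity at p = 0, `warp` is a hand-written bilinear sampler over a normalized [-1, 1] grid, and `predict_update` stands in for the localization network. The loop predicts an update from the currently warped image, composes it with the running warp, and always resamples the original image, so repeated steps do not accumulate interpolation blur.

```python
import numpy as np

def params_to_homography(p):
    """Map 8 warp parameters to a 3x3 homography; p = 0 gives the identity (assumed form)."""
    p = np.asarray(p, dtype=np.float64)
    return np.array([[1.0 + p[0], p[1], p[2]],
                     [p[3], 1.0 + p[4], p[5]],
                     [p[6], p[7], 1.0]])

def homography_to_params(H):
    """Inverse of params_to_homography after normalizing H[2, 2] to 1."""
    H = H / H[2, 2]
    return np.array([H[0, 0] - 1.0, H[0, 1], H[0, 2],
                     H[1, 0], H[1, 1] - 1.0, H[1, 2],
                     H[2, 0], H[2, 1]])

def warp(image, H):
    """Sample a 2-D image at H @ (x, y, 1) for every output pixel, using
    normalized [-1, 1] coordinates, bilinear interpolation and zero padding."""
    h, w = image.shape
    ys, xs = np.meshgrid(np.linspace(-1, 1, h), np.linspace(-1, 1, w), indexing="ij")
    grid = np.stack([xs.ravel(), ys.ravel(), np.ones(h * w)])  # 3 x (h*w)
    src = H @ grid
    # Projective divide, then map normalized coordinates back to pixel indices.
    sx = (src[0] / src[2] + 1.0) * (w - 1) / 2.0
    sy = (src[1] / src[2] + 1.0) * (h - 1) / 2.0
    x0, y0 = np.floor(sx).astype(int), np.floor(sy).astype(int)
    out = np.zeros(h * w)
    for dy in (0, 1):
        for dx in (0, 1):
            xi, yi = x0 + dx, y0 + dy
            weight = (1.0 - np.abs(sx - xi)) * (1.0 - np.abs(sy - yi))
            inside = (xi >= 0) & (xi < w) & (yi >= 0) & (yi < h)
            out[inside] += weight[inside] * image[yi[inside], xi[inside]]
    return out.reshape(h, w)

def icstn(image, predict_update, num_steps=4):
    """Inverse compositional loop: predict, compose, resample the ORIGINAL image."""
    p = np.zeros(8)
    for _ in range(num_steps):
        dp = predict_update(warp(image, params_to_homography(p)))
        # warp(warp(I, H_p), H_dp) samples I at H_p @ H_dp, so compose in that order.
        p = homography_to_params(params_to_homography(p) @ params_to_homography(dp))
    return warp(image, params_to_homography(p)), p
```

A plain homography STN corresponds to a single step (`num_steps=1`): one prediction, one resampling.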

To evaluate and visualize results:

python evaluate.py --output_dir=[path to params.json] \
                   --restore_file=[path to .pt checkpoint] \
                   --cuda=[cuda device id]

Dependencies

  • python 3.6
  • pytorch 0.4
  • torchvision
  • tensorboardX
  • numpy
  • matplotlib
  • tqdm

Useful resources
