This repository repurposes the generative architecture of Razavi et al.'s multi-level Vector Quantized Variational AutoEncoder (VQ-VAE-2) to compress medical images in PyTorch.
Additionally, the compressed latent vectors and reconstructed images have been used to train the CheXNet (DenseNet-121 pre-trained on ImageNet) algorithm.
This repository supports a two-level VQ-VAE (top and bottom hierarchical layers). The vector quantization workflow is outlined below.
Two levels of convolution-based encoding capture both local (first-layer) and global (second-layer) features.
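For orientation, the quantization step at each level is a nearest-codebook lookup with a straight-through gradient. The sketch below is illustrative only and uses assumed module and parameter names (`VectorQuantizer`, `num_embeddings`, `beta`), not the repository's actual classes:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class VectorQuantizer(nn.Module):
    """Nearest-neighbour codebook lookup used at each hierarchy level (illustrative sketch)."""
    def __init__(self, num_embeddings=512, embedding_dim=64, beta=0.25):
        super().__init__()
        self.codebook = nn.Embedding(num_embeddings, embedding_dim)
        self.codebook.weight.data.uniform_(-1.0 / num_embeddings, 1.0 / num_embeddings)
        self.beta = beta  # commitment cost

    def forward(self, z):
        # z: (B, C, H, W) continuous encoder output; flatten spatial positions
        z_perm = z.permute(0, 2, 3, 1).contiguous()              # (B, H, W, C)
        flat = z_perm.view(-1, z_perm.shape[-1])                 # (B*H*W, C)
        # squared L2 distance from each position to every codebook vector
        dist = (flat.pow(2).sum(1, keepdim=True)
                - 2 * flat @ self.codebook.weight.t()
                + self.codebook.weight.pow(2).sum(1))
        indices = dist.argmin(1)                                 # discrete latent codes
        z_q = self.codebook(indices).view_as(z_perm)
        # codebook + commitment losses, straight-through gradient estimator
        loss = F.mse_loss(z_q, z_perm.detach()) + self.beta * F.mse_loss(z_perm, z_q.detach())
        z_q = z_perm + (z_q - z_perm).detach()
        return z_q.permute(0, 3, 1, 2).contiguous(), indices, loss
```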
We converted our datasets into HDF5 format as inputs for faster training. We used the MIMIC-CXR and CheXpert datasets for training and external validation.
- python3
- PyTorch (torch)
- torchvision
- HDF5 (h5py)
- numpy
- tqdm
- matplotlib
- scikit-learn (sklearn)
We used HDF5 datasets to create and save padded images such that the training does not require pre-processing each time.
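As an illustration only (the dataset class and the `images` key name are assumptions, not the repository's actual code), a padded-image HDF5 file can be wrapped in a PyTorch `Dataset` along these lines:

```python
import h5py
import torch
from torch.utils.data import Dataset

class HDF5ImageDataset(Dataset):
    """Reads pre-padded images from an HDF5 file; 'images' is an assumed key name."""
    def __init__(self, h5_path):
        self.h5_path = h5_path
        self.file = None
        with h5py.File(h5_path, 'r') as f:
            self.length = f['images'].shape[0]

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        # open lazily so each DataLoader worker gets its own file handle
        if self.file is None:
            self.file = h5py.File(self.h5_path, 'r')
        img = self.file['images'][idx]                             # (H, W) array
        img = torch.from_numpy(img).float().unsqueeze(0) / 255.0   # (1, H, W) in [0, 1], assuming uint8 storage
        return img
```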
- To run the VQ-VAE training script using the default hyperparameters:
python train_vqvae.py --data_path=[HDF5 TRAIN DATASET PATH] --save_path=[SAVE PATH]
- To increase the compression ratio, change the stride at each hierarchy with the `first_stride` or `second_stride` flags:
python train_vqvae.py --data_path=[HDF5 TRAIN DATASET PATH] --save_path=[SAVE PATH] --first_stride=4
Note: strides increase in powers of 2: 2, 4, 8, 16 (a rough latent-size calculation follows below).
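As a back-of-the-envelope illustration of how the stride flags relate to compression, the helper below estimates the latent map sizes; the flag semantics here are our reading of the script, not guaranteed by it:

```python
def latent_map_size(image_size, first_stride=2, second_stride=2):
    """Approximate spatial sizes of the bottom and top latent maps for a square input.

    Assumes the bottom encoder downsamples by `first_stride` and the top encoder
    downsamples the bottom map by a further `second_stride` (powers of 2: 2, 4, 8, 16).
    """
    bottom = image_size // first_stride
    top = bottom // second_stride
    return bottom, top

# e.g. a 1024x1024 chest X-ray with --first_stride=4 and the default second stride
bottom, top = latent_map_size(1024, first_stride=4, second_stride=2)
print(bottom, top)  # 256, 128 -> spatial dimensions shrink by 4x (bottom) and 8x (top)
```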
- To train the DenseNet-121 classifier:
python train_densenet.py --vqvae_file=[CHECKPOINT FILE FROM ABOVE TRAINING]
The training of DenseNet-121 can be conducted with original images, latent vectors, or reconstructed images:
Note: checkpoint files can be found in the `[save_path]/checkpoints` directory from the training.
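Conceptually, the classifier consumes either the raw images, the frozen VQ-VAE's compressed latents, or its reconstructions. A rough sketch of producing those inputs from a checkpoint follows; the checkpoint layout, file path, and the `encode`/`decode` method names are assumptions, not the repository's confirmed API:

```python
import torch

# load a VQ-VAE checkpoint saved by train_vqvae.py (path and layout are illustrative)
ckpt = torch.load('save_path/checkpoints/vqvae_latest.pt', map_location='cpu')
vqvae = ckpt['model'] if isinstance(ckpt, dict) and 'model' in ckpt else ckpt  # may instead be a state_dict
vqvae.eval()

with torch.no_grad():
    batch = torch.rand(2, 1, 1024, 1024)   # stand-in for a batch of padded chest X-rays
    latents = vqvae.encode(batch)          # assumed encoder API -> compressed latent input
    recon = vqvae.decode(latents)          # assumed decoder API -> reconstructed image input
# `batch`, `latents`, or `recon` can then be fed to the DenseNet-121 classifier
```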
- Use the `test_model.ipynb` Jupyter Notebook to:
  - create a padded image from any image file to fit a square aspect ratio
  - save reconstructed images from trained models
  - calculate PSNR from saved images (from `create_images.py`); a minimal padding/PSNR sketch follows below
Note: loading the saved models requires a CUDA-enabled device. If the device does not have CUDA, load the file with:
torch.load('checkpoint.pt', map_location='cpu')
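For reference, padding to a square aspect ratio and computing PSNR can be done roughly as follows; this is a numpy sketch, not the notebook's exact code:

```python
import numpy as np

def pad_to_square(img, fill=0):
    """Pad a 2-D grayscale array with `fill` so that height == width (centered)."""
    h, w = img.shape
    size = max(h, w)
    out = np.full((size, size), fill, dtype=img.dtype)
    top, left = (size - h) // 2, (size - w) // 2
    out[top:top + h, left:left + w] = img
    return out

def psnr(original, reconstructed, max_val=255.0):
    """Peak signal-to-noise ratio in dB between two arrays of the same shape."""
    mse = np.mean((original.astype(np.float64) - reconstructed.astype(np.float64)) ** 2)
    return float('inf') if mse == 0 else 10.0 * np.log10(max_val ** 2 / mse)
```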
- To run the profiling code on DenseNet-121 training, uncomment the `@profile` decorator in line 266 of `networks.py` (a decorator sketch follows after the sample output below). Once the training is complete, the `pytorch_memlab` library will output the profiling info directly to the terminal:
Line # Max usage Peak usage diff max diff peak Line Contents
===============================================================
266 @profile
267 def forward(self, input):
268 108.90M 164.00M 79.89M 118.00M if self.input_type == 'latent':
269 111.90M 164.00M 3.00M 0.00B input = self.init_conv(input) # convert to 3 channel input
270 770.49M 788.00M 658.59M 624.00M output = self.model(input)
271 770.49M 840.00M 0.00B 52.00M return output
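For context, pytorch_memlab's line-by-line memory profiler is attached by decorating the method, roughly as shown below; the surrounding class is a simplified stand-in for the classifier in `networks.py`, not its actual definition:

```python
import torch.nn as nn
from pytorch_memlab import profile
from torchvision.models import densenet121

class Classifier(nn.Module):
    """Simplified stand-in for the DenseNet-121 classifier wrapper."""
    def __init__(self, input_type='latent', in_channels=1):
        super().__init__()
        self.input_type = input_type
        self.init_conv = nn.Conv2d(in_channels, 3, kernel_size=1)  # map latents to a 3-channel input
        self.model = densenet121(num_classes=14)

    @profile  # toggle this decorator to enable/disable line-by-line memory reporting
    def forward(self, input):
        if self.input_type == 'latent':
            input = self.init_conv(input)  # convert to 3 channel input
        output = self.model(input)
        return output
```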
- Loss curves are automatically generated in the `[save_path]` directory from the training.
- Reconstruction performance is satisfactory when evaluated with external datasets. In the example below, the algorithm was trained with the CheXpert dataset (frontal view) and externally validated with the MIMIC-CXR dataset (both frontal and lateral views).
The trained model is robust to various input manipulations. Input image above, reconstructed image below:
- Classification performance of DenseNet-121, as determined by AUROC, was satisfactory with the original images and actually increased with the reconstructed images and compressed latent vectors as input. We suspect that the VQ-VAE-2 acts as a denoising autoencoder.
Download links for: saved models and original and reconstructed images from the validation MIMIC-CXR dataset
- Young Joon (Fred) Kwon MS |github|linkedin| MD PhD Student; Icahn School of Medicine at Mount Sinai
- G Anthony (Tony) Reina MD |github|linkedin| Chief AI Architect for Health & Life Sciences; Intel Corporation
- Ping Tak Peter Tang PhD |github|linkedin| Research Scientist; Facebook
- Eric K Oermann MD |github|linkedin| Instructor, Department of Neurosurgery; Director, AISINAI; Icahn School of Medicine at Mount Sinai
- Anthony B Costa PhD |github|linkedin| Assistant Professor, Department of Neurosurgery; Director, Sinai BioDesign; Icahn School of Medicine at Mount Sinai
This project is licensed under the Apache License, Version 2.0 - see the LICENSE.txt file for details.
- MSTP T32 NIH T32 GM007280
- RSNA Medical Student Research Grant
- Intel Software and Services Group Research Grant