English | 简体中文
PaddlePaddle training/validation code and pretrained models for GAN.
This implementation is part of PaddleViT project.
Update (2021-08-25): Initial readme uploaded.
The following links are provided for the code and detailed usage of each model architecture:
This module is tested on Python 3.6+ and PaddlePaddle 2.1.0+. Most dependencies are installed with PaddlePaddle. You only need to install the following packages:

```shell
pip install yacs pyyaml lmdb
```
Then clone the GitHub repo:

```shell
git clone https://github.com/xperzy/PPViT.git
cd PPViT/gan
```
Cifar10, STL10, Celeba, and LSUNchurch datasets are used with the following folder structure (we use `paddle.vision.datasets.Cifar10` to create the Cifar10 dataset, so downloading and preparing the data manually is NOT needed):
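The other datasets are loaded through map-style dataset classes (see each model's dataset.py). The contract such a class must satisfy is just `__getitem__` plus `__len__`. Below is a minimal pure-Python sketch of that pattern; the class names and the stand-in base class are hypothetical (so it runs without paddle installed), while the real classes in PaddleViT subclass `paddle.io.Dataset`:

```python
# Sketch of the map-style dataset contract expected by paddle.io.Dataset:
# a subclass defines __getitem__ and __len__.

class Dataset:  # stand-in for paddle.io.Dataset so this runs without paddle
    pass

class FakeImageDataset(Dataset):
    """Hypothetical dataset yielding (image, label) pairs from a list."""
    def __init__(self, samples):
        self.samples = samples  # list of (image, label) tuples

    def __getitem__(self, idx):
        return self.samples[idx]

    def __len__(self):
        return len(self.samples)

ds = FakeImageDataset([([0.0] * 4, 0), ([1.0] * 4, 1)])
print(len(ds))   # 2
print(ds[1][1])  # 1
```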
```
│STL10/
├── train_X.bin
├── train_y.bin
├── test_X.bin
├── test_y.bin
└── unlabeled.bin

│Celeba/
├── img_align_celeba/
│   ├── 000017.jpg
│   ├── 000019.jpg
│   ├── 000026.jpg
│   ├── ......

│LSUNchurch/
├── church_outdoor_train_lmdb/
│   ├── data.mdb
│   └── lock.mdb
```
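If you prepare STL10, Celeba, or LSUNchurch manually, a small stdlib helper can confirm the layout above is in place before training. The `check_layout` function and the file lists below are illustrative, not part of PaddleViT; adjust the paths to your setup:

```python
import os

# Expected files per dataset root, mirroring the folder structure above.
EXPECTED = {
    'STL10': ['train_X.bin', 'train_y.bin', 'test_X.bin',
              'test_y.bin', 'unlabeled.bin'],
    'LSUNchurch': [os.path.join('church_outdoor_train_lmdb', 'data.mdb'),
                   os.path.join('church_outdoor_train_lmdb', 'lock.mdb')],
}

def check_layout(root, files):
    """Return the list of expected files missing under root."""
    return [f for f in files if not os.path.exists(os.path.join(root, f))]

missing = check_layout('./STL10', EXPECTED['STL10'])
print(missing)  # [] when the dataset is fully in place
```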
To use a specific model, go to its model folder and download the pretrained weight file, e.g., `./cifar10.pdparams`. Then, to use the `styleformer_cifar10` model in Python:
```python
import paddle
from config import get_config
from generator import Generator
# config files in ./configs/
config = get_config('./configs/styleformer_cifar10.yaml')
# build model
model = Generator(config)
# load pretrained weights
model_state_dict = paddle.load('./cifar10.pdparams')
model.set_dict(model_state_dict)
```
To generate sample images from a pretrained model, download the pretrained weights and run the following script from the command line:

```shell
sh run_generate.sh
```

or

```shell
python generate.py \
    -cfg=./configs/styleformer_cifar10.yaml \
    -num_out_images=16 \
    -out_folder=./images_cifar10 \
    -pretrained=/path/to/pretrained/model/cifar10  # .pdparams is NOT needed
```
The output images are stored in the `-out_folder` path.
🤖 See the README file in each model folder for detailed usage.
The PaddleViT GAN module is developed in separate folders for each model with a similar structure. Each implementation is built around 3 types of classes and 2 types of scripts:
- Model classes such as `ViT_custom.py`, in which the core transformer model and related methods are defined.
- Dataset classes such as `dataset.py`, in which the dataset, dataloader, and data transforms are defined. We provide flexible implementations for you to customize the data loading scheme. Both single-GPU and multi-GPU loading are supported.
- Config classes such as `config.py`, in which the model and training/validation configurations are defined. Usually, you don't need to change the items in the configuration; we support updating configs by python arguments or `.yaml` config file. You can see here for details of our configuration design and usage.
- Main scripts such as `main_single_gpu.py`, in which the whole training/validation procedures are defined. The major steps of training or validation are provided, such as logging, loading/saving models, finetuning, etc. Multi-GPU training is also supported and implemented in a separate python script, `main_multi_gpu.py`.
- Run scripts such as `run_eval_cifar.sh`, in which the shell commands for running python scripts with specific configs and arguments are defined.
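The config workflow described above (defaults, then an optional `.yaml` file, then command-line overrides) can be sketched with a plain dict standing in for the yacs `CfgNode` that `config.py` actually uses. All keys and flag names below are illustrative, not PaddleViT's actual option names:

```python
import argparse
import copy

# Simplified stand-in for the yacs-based config.py: start from defaults,
# then merge overrides layer by layer (config file, then CLI arguments).
DEFAULTS = {'DATA': {'BATCH_SIZE': 32}, 'MODEL': {'NAME': 'styleformer'}}

def merge(base, updates):
    """Recursively merge `updates` into a copy of `base`."""
    out = copy.deepcopy(base)
    for key, val in updates.items():
        if isinstance(val, dict) and isinstance(out.get(key), dict):
            out[key] = merge(out[key], val)
        else:
            out[key] = val
    return out

parser = argparse.ArgumentParser()
parser.add_argument('-batch_size', type=int, default=None)
args = parser.parse_args(['-batch_size', '64'])  # simulate CLI input

config = DEFAULTS
if args.batch_size is not None:
    config = merge(config, {'DATA': {'BATCH_SIZE': args.batch_size}})
print(config['DATA']['BATCH_SIZE'])  # 64
```

In the real module, `get_config()` performs these merge steps and the result is passed to the model and training loop.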
PaddleViT now provides the following transformer-based models:
- TransGAN (from UT-Austin and MIT-IBM Watson AI Lab), released with paper TransGAN: Two Pure Transformers Can Make One Strong GAN, and That Can Scale Up, by Yifan Jiang, Shiyu Chang, Zhangyang Wang.
- Styleformer (from Seoul National University), released with paper Styleformer: Transformer based Generative Adversarial Networks with Style Vector, by Jeeseung Park, Younggeun Kim.
If you have any questions, please create an issue on our GitHub.