Skip to content

yufanxin/Torch_Classify

Repository files navigation

深度学习在图像分类中的应用

前言

起因:因为我看github上面很多其他计算机视觉任务的集成,都写得很好了,但是分类这块,一直没找到我想要的那种,索性自己整理了一个符合自己需求的。以后也会陆续添加模型。

  • 本教程是对本人本科生期间的研究内容进行整理总结,总结的同时也希望能够帮助更多的小伙伴。后期如果有学习到新的知识也会与大家一起分享。
  • 本教程使用Pytorch进行网络的搭建与训练。
  • 本教程参考的链接附在最后。感谢大家的支持。

目前

2021.08.05

  • 纠正了环境配置
  • 增加了Focal Loss
  • 增加了tools中可视化的工具
  • 优化了记录每次exp的result及plot
  • 修复了predict.py读取class_indices.json的BUG
  • 修复了adamw优化器和plot_lr冲突的BUG

2021.07.29

  • 纠正了环境配置
  • 增加了tools中创建数据的工具

2021.07.28

  • 增加了ResMlp-Mixer VoVNet se-resnet SqueezeNet MnasNet模型
  • 增加了tools中转换权重、计算模型参数、模型FPS、模型吞吐量的工具
  • 更新了权重加载方式和权重链接
  • 优化了logs文件夹,记录每次exp的config,添加requirements.txt并纠正环境配置
  • 修复了warmup_epoch=0的BUG

2021.07.25

  • first commit

待做

  • TensorRT加速
  • Swin-Transformer

支持模型

#  --------------------------------------------------------------------------------------
# |model_prefix    |model_suffix                                                         |
# |--------------------------------------------------------------------------------------|
# |vgg             |11 13 16 19 bn11 bn13 bn16 bn19                                      |
# |--------------------------------------------------------------------------------------|
# |resnet          |18 34 50 101 152                                                     |
# |--------------------------------------------------------------------------------------|
# |resnext         |50-32x4d 101-32x8d                                                   |
# |--------------------------------------------------------------------------------------|
# |regnetx         |200mf 400mf 600mf 800mf 1.6gf 3.2gf 4.0gf 6.4gf 8.0gf 12gf 16gf 32gf |
# |--------------------------------------------------------------------------------------|
# |regnety         |200mf 400mf 600mf 800mf 1.6gf 3.2gf 4.0gf 6.4gf 8.0gf 12gf 16gf 32gf |
# |--------------------------------------------------------------------------------------|
# |mobilenetv2     |0.25, 0.5, 0.75, 1.0, 1.25, 1.5                                      |
# |--------------------------------------------------------------------------------------|
# |mobilenetv3     |small large                                                          |
# |--------------------------------------------------------------------------------------|
# |ghostnet        |0.5 1.0 1.3                                                          |
# |--------------------------------------------------------------------------------------|
# |efficientnetv1  |b0 b1 b2 b3 b4 b5 b6 b7                                              |
# |--------------------------------------------------------------------------------------|
# |efficientnetv2  |small medium large                                                   |
# |--------------------------------------------------------------------------------------|
# |shufflenetv2    |0.5 1.0 1.5 2.0                                                      |
# |--------------------------------------------------------------------------------------|
# |densenet        |121 161 169 201                                                      |
# |--------------------------------------------------------------------------------------|
# |xception        |299                                                                  |
# |--------------------------------------------------------------------------------------|
# |vit             |base-patch16 base-patch32 large-patch16 large-patch32 huge-patch14   |
#  --------------------------------------------------------------------------------------
# |resmlp-mixer    |12 24 36 B24                                                         |
#  --------------------------------------------------------------------------------------
# |vovnet          |27slim 39 57                                                         |
#  --------------------------------------------------------------------------------------
# |se-resnet       |18 34 50 101 152                                                     |
#  --------------------------------------------------------------------------------------
# |squeezenet      |1.0 1.1                                                              |
#  --------------------------------------------------------------------------------------
# |mnasnet         |0.5 0.75 1.0 1.3                                                     |
#  --------------------------------------------------------------------------------------

训练准备

数据格式

# -data
#    -train
#       -class_0
#          -1.jpg
#       -class_1
#       -...
#    -val
#       -class_0
#       -class_1
#       -...

环境配置

  • Anaconda3
  • python 3.8
  • pycharm (IDE, 建议使用)
  • pytorch 1.8.1
  • apex 0.1.0
  • VS2019
  • Cuda10.2

安装Pytorch

conda create -n classify python=3.8
conda activate classify
conda install pytorch=1.8 torchvision cudatoolkit=10.2 -c pytorch

安装apex (非必须,若不需要则config.py中use_apex=False)

cd apex-master
pip install -r requirements.txt
pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" .

注意语句最后的点也要复制

训练

在config.py中修改你想要的模型配置,注意,我的代码中,每个模型有2部分组成,分别是model_prefix和model_suffix。

例如

model_prefix='shufflenetv2'
model_suffix='0.5'

为了方便大家,我写了关于参数用途的注释。配置好之后运行train.py

可视化相关指标

训练完成之后在你的log_dir中查看训练过程。

预测

  • 我只写了单张图片的预测,但是你可以在我的基础上很灵活的更改成适合你项目需求的预测代码。

  • 同样的,在config.py中修改load_from,predict_img_path,注意这里img_path不再有效,因为img_path只针对训练。

模型权重

参考

  1. https://github.com/pytorch/vision/tree/master/torchvision/models
  2. https://github.com/WZMIAOMIAO/deep-learning-for-image-processing/tree/master/pytorch_classification
  3. https://github.com/rwightman/pytorch-image-models/tree/master/timm/models
  4. https://github.com/yizt/Grad-CAM.pytorch

联系方式

  1. QQ:2267330597
  2. E-mail:201902098@stu.sicau.edu.cn

About

No description, website, or topics provided.

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published

Languages