This project demonstrates an advanced image classification model using transfer learning with a pre-trained ResNet50 model. The model is trained on the CIFAR-10 dataset, which consists of 60,000 32x32 color images in 10 different classes. This project covers various stages, including data preprocessing, model training, evaluation, and deployment using a Flask web application.
- Data Preprocessing and Augmentation: Includes advanced data augmentation techniques to enhance the training dataset.
- Transfer Learning: Utilizes a pre-trained ResNet50 model for feature extraction and adds custom layers for classification.
- Model Evaluation: Detailed evaluation using metrics such as accuracy, precision, recall, F1-score, and confusion matrix.
- Model Explainability: Visualizes which parts of an image contribute most to the model's predictions using Grad-CAM.
- Deployment: Deploys the trained model using a Flask web application for real-time image classification.
To get started, clone the repository and install the necessary dependencies.
git clone https://github.com/SlimenFellah/Image-Classification-Model-with-Transfer-Learning.git
cd Image-Classification-Model-with-Transfer-Learning
pip install -r requirements.txt
Perform exploratory data analysis (EDA) to understand the dataset and to visualize sample images and the class distribution.
python eda.py
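The class-distribution part of this step can be sketched in a few lines. The helper `class_distribution` is a hypothetical name used for illustration and may not match what `eda.py` does; the commented lines show how the real labels could be loaded through torchvision's `CIFAR10` dataset, which exposes `targets` and `classes` attributes.

```python
from collections import Counter


def class_distribution(labels, class_names):
    """Return {class name: image count}, including classes with zero images."""
    counts = Counter(labels)
    return {name: counts.get(idx, 0) for idx, name in enumerate(class_names)}


# With torchvision installed, the real labels could be loaded like this:
#   from torchvision.datasets import CIFAR10
#   train = CIFAR10(root="./data", train=True, download=True)
#   print(class_distribution(train.targets, train.classes))

# Tiny stand-in data so the sketch runs without downloading the dataset.
demo = class_distribution([0, 1, 1, 2], ["airplane", "automobile", "bird"])
print(demo)  # {'airplane': 1, 'automobile': 2, 'bird': 1}
```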
Train the image classification model using the CIFAR-10 dataset. This step includes data augmentation, transfer learning, and hyperparameter tuning.
python model.py
Evaluate the trained model using various metrics and visualize the results.
python evaluation.py
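The metrics named in this README can all be computed with scikit-learn once the true and predicted labels are collected. The following is a rough sketch, not the contents of `evaluation.py`; the function name `evaluate` and the choice of macro averaging are assumptions.

```python
from sklearn.metrics import (
    accuracy_score,
    confusion_matrix,
    precision_recall_fscore_support,
)


def evaluate(y_true, y_pred, num_classes=10):
    """Return accuracy, macro precision/recall/F1, and the confusion matrix."""
    labels = list(range(num_classes))
    # zero_division=0 avoids warnings for classes that were never predicted.
    precision, recall, f1, _ = precision_recall_fscore_support(
        y_true, y_pred, labels=labels, average="macro", zero_division=0
    )
    return {
        "accuracy": accuracy_score(y_true, y_pred),
        "precision": precision,
        "recall": recall,
        "f1": f1,
        # Rows are true classes, columns are predicted classes.
        "confusion_matrix": confusion_matrix(y_true, y_pred, labels=labels),
    }
```

The returned confusion matrix is a NumPy array, so it can be passed directly to `seaborn.heatmap` for visualization.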
Deploy the model using a Flask web application. This allows you to upload images and get real-time classification results.
python deployment.py
The Flask app will be available at http://127.0.0.1:5000.
- Start the Flask application:
  python deployment.py
- Use a tool like curl or Postman to send a POST request to the /predict endpoint with an image file:
  curl -X POST -F "file=@path/to/your/image.jpg" http://127.0.0.1:5000/predict
You should receive a JSON response with the predicted class of the image.
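To make the request and response flow concrete, here is a hedged sketch of what a `/predict` endpoint of this kind might look like. It is not the code in `deployment.py`: the `create_app` factory, the `predict_class` placeholder, the `predicted_class` JSON key, and the error messages are all assumptions for illustration. Only the route, the `file` form field, and the CIFAR-10 class names come from the text and the dataset.

```python
import io

from flask import Flask, jsonify, request
from PIL import Image

# The ten CIFAR-10 class names, in label order.
CLASSES = ["airplane", "automobile", "bird", "cat", "deer",
           "dog", "frog", "horse", "ship", "truck"]


def predict_class(image):
    """Placeholder classifier; a real app would run the trained model here."""
    return 0


def create_app(classifier=predict_class):
    app = Flask(__name__)

    @app.route("/predict", methods=["POST"])
    def predict():
        # The curl example sends the image in the multipart field named "file".
        upload = request.files.get("file")
        if upload is None or upload.filename == "":
            return jsonify(error="no file uploaded"), 400
        try:
            image = Image.open(io.BytesIO(upload.read())).convert("RGB")
        except OSError:
            return jsonify(error="not a readable image"), 400
        return jsonify(predicted_class=CLASSES[classifier(image)])

    return app


# To serve locally: create_app().run(host="127.0.0.1", port=5000)
```

Passing the classifier into the factory keeps the HTTP handling separate from the model, so the endpoint can be exercised with Flask's test client without loading any weights.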
- main.py: Main script to run the entire project.
- eda.py: Script for exploratory data analysis.
- model.py: Script to define and train the model.
- evaluation.py: Script for model evaluation and visualization.
- deployment.py: Script to deploy the model using Flask.
- requirements.txt: File containing the list of dependencies.
- torch
- torchvision
- Flask
- numpy
- matplotlib
- seaborn
- scikit-learn
- Pillow
This project utilizes the CIFAR-10 dataset, obtained from the torchvision library, and the ResNet50 model pre-trained on ImageNet, also from torchvision.
This project is licensed under the MIT License - see the LICENSE file for details.