This repository contains the code for training a multi-class image classifier using transfer learning with VGG16, ResNet50, Inception ResNet V2, and MobileNet V2 in TensorFlow.
- data_loader.py: Contains the code for loading the dataset from a folder structured as class_name/*.jpg.
- models.py: Class for TensorFlow model loading.
- train.py: Code for training the model.
- test.py: Code for testing the model.
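To illustrate the class_name/*.jpg layout that the data loader expects, here is a minimal sketch of how such a folder could be indexed into file paths and integer labels. It is not the code in data_loader.py; the function name `list_labeled_images` and its return format are assumptions made for this example, and it uses only the Python standard library.

```python
from pathlib import Path


def list_labeled_images(dataset_dir):
    """Index a dataset laid out as <dataset_dir>/<class_name>/*.jpg.

    Hypothetical helper for illustration only. Returns (image_paths, labels,
    class_names), where labels[i] is the index of the class of image_paths[i]
    within the alphabetically sorted class_names list.
    """
    root = Path(dataset_dir)
    # Each sub-directory is treated as one class; sorting keeps label ids stable.
    class_names = sorted(p.name for p in root.iterdir() if p.is_dir())
    class_to_index = {name: index for index, name in enumerate(class_names)}

    image_paths, labels = [], []
    for name in class_names:
        for image_path in sorted((root / name).glob("*.jpg")):
            image_paths.append(str(image_path))
            labels.append(class_to_index[name])
    return image_paths, labels, class_names
```

Sorting the class names makes the label assigned to each class reproducible across runs, which matters when a trained model is later evaluated by a separate script.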
The dataset used for training and the pre-trained models can be downloaded from here.
Create a conda environment:

    conda create -n image_classifier python=3.6.3
    conda activate image_classifier

Install the required libraries:

    pip3 install -r requirements.txt

Unzip the downloaded dataset into the ./dataset folder and set the required parameters in config.yaml.
To train the model, run:

    python train.py --config_path <path to config file>

This repository currently supports the following models:

- VGG16
- ResNet50
- Inception ResNet V2
- MobileNet V2
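As a rough illustration of transfer learning with the models listed above, the sketch below builds a classifier from a frozen backbone in `tf.keras.applications` and adds a new softmax head. It is not the code in models.py: the function `build_classifier`, the `SUPPORTED_MODELS` tuple, and the choice of a global-average-pooling head are assumptions made for this example, and per-model input preprocessing is left out for brevity.

```python
# Names of the tf.keras.applications constructors for the supported backbones.
SUPPORTED_MODELS = ("VGG16", "ResNet50", "InceptionResNetV2", "MobileNetV2")


def build_classifier(backbone_name, num_classes,
                     input_shape=(224, 224, 3), weights="imagenet"):
    """Hypothetical helper: frozen pre-trained backbone plus a new softmax head."""
    if backbone_name not in SUPPORTED_MODELS:
        raise ValueError("Unsupported model: {}".format(backbone_name))

    # Imported here so the name check above works without TensorFlow installed.
    import tensorflow as tf

    backbone_cls = getattr(tf.keras.applications, backbone_name)
    # include_top=False drops the original ImageNet classification layers.
    backbone = backbone_cls(include_top=False, weights=weights,
                            input_shape=input_shape)
    backbone.trainable = False  # only the new head is trained

    inputs = tf.keras.Input(shape=input_shape)
    # training=False keeps batch-normalization layers in inference mode.
    features = backbone(inputs, training=False)
    pooled = tf.keras.layers.GlobalAveragePooling2D()(features)
    outputs = tf.keras.layers.Dense(num_classes, activation="softmax")(pooled)
    return tf.keras.Model(inputs, outputs)
```

A model built this way could then be compiled with a categorical cross-entropy loss and fitted on the training images. Passing `weights=None` skips the download of pre-trained weights, which is useful for a quick shape check but defeats the purpose of transfer learning.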
To test the model, run:

    python test.py --config_path <path to config file>

The test output is stored in the ./figures folder.

| Model Name | VGG16 | ResNet50 | Inception ResNet V2 | MobileNet V2 |
|---|---|---|---|---|
| Accuracy (%) | 88.89 | 33.33 | 88.89 | 78 |
Inception ResNet V2 matches VGG16 in accuracy but can be considered the best of the four models: it is lightweight, which makes it computationally efficient, and it can be trained on a single GPU, which makes it less resource intensive.