Train and deploy a cat vs dog image recognition model using TensorFlow

leemengtw | Last update: Nov 10, 2023

Cat-recognition-train

This repository demonstrates how to train a cat vs dog recognition model with TensorFlow and export it to an optimized frozen graph that is easy to deploy. If you want to know how to deploy a Flask app that recognizes cats and dogs using TensorFlow, please visit cat-recognition-app.

Requirements

  • Python3 (Tested on 3.6.8)
  • TensorFlow (Tested on 1.12.0)
  • NumPy (Tested on 1.15.1)
  • tqdm (Tested on 4.29.1)
  • Dogs vs. Cats dataset from https://www.kaggle.com/c/dogs-vs-cats
  • PyTorch, optional and only needed to run the tests (Tested on 1.0.0 and 1.0.1)

Build environment

We recommend using Anaconda3 / Miniconda3 to manage your Python environment.

If the machine you're using does not have a GPU, you can simply run:

$ pip install -r requirements.txt

or

$ conda install --file requirements.txt

However, if you want to use a GPU to accelerate training, please see TensorFlow - GPU support for more information.

Train a Convolutional Neural Network

In this part, we use TensorFlow to train a CNN that distinguishes images of cats from images of dogs, using the Kaggle Dogs vs. Cats dataset. The workflow consists of the following steps:

  • Split the images into training and validation sets (dataset.py)
  • Load, augment, resize and normalize the images using the tf.data.Dataset API (dataset.py)
  • Define a CNN model (net.py)
    • Here we use the ShuffleNetV2 architecture, which offers a good trade-off between speed and accuracy.
    • We apply transfer learning to ShuffleNetV2, initializing it with the pretrained weights from https://github.com/ericsun99/Shufflenet-v2-Pytorch.
    • To see how the PyTorch weights are loaded into the TensorFlow model graph, check convert_pytorch_weight_test, starting at line 44 of module_tests.py.
  • Train the CNN model (train.py)
  • Serialize the model for deployment (train.py)
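Converting the pretrained weights mostly comes down to reordering tensor axes. The sketch below only illustrates that idea and is not the code in module_tests.py. It assumes a regular, non-grouped convolution whose PyTorch kernel has shape (out_channels, in_channels, kH, kW). It rearranges that kernel into the (kH, kW, in_channels, out_channels) layout that TensorFlow's tf.nn.conv2d expects. Nested Python lists stand in for tensors so the example runs without either framework. The name pytorch_conv_to_tf is hypothetical.

```python
from typing import List

# Nested-list stand-in for a 4-D tensor, so no framework is required.
Tensor4D = List[List[List[List[float]]]]


def pytorch_conv_to_tf(weight: Tensor4D) -> Tensor4D:
    """Reorder a conv kernel from PyTorch's (out, in, kH, kW) layout
    to TensorFlow's (kH, kW, in, out) layout (hypothetical helper)."""
    out_c = len(weight)
    in_c = len(weight[0])
    k_h = len(weight[0][0])
    k_w = len(weight[0][0][0])
    # result[h][w][i][o] == weight[o][i][h][w]
    return [
        [
            [
                [weight[o][i][h][w] for o in range(out_c)]
                for i in range(in_c)
            ]
            for w in range(k_w)
        ]
        for h in range(k_h)
    ]
```

With NumPy the same reordering is np.transpose(weight, (2, 3, 1, 0)). Depthwise kernels, which ShuffleNetV2 uses heavily, need a different axis order. PyTorch stores them as (channels, 1, kH, kW), while tf.nn.depthwise_conv2d expects (kH, kW, channels, channel_multiplier), i.e. np.transpose(weight, (2, 3, 0, 1)). This sketch does not cover that case.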

If you want to execute the code, make sure all required packages are installed and the Dogs vs. Cats training images are placed in the datasets folder. The folder structure should look like this:

cat-recognition-train
+-- train.py
+-- net.py
+-- dataset.py
+-- datasets
    +-- train
    |   +-- cat.0.jpg
    |   +-- cat.1.jpg
    |   ...
    |   +-- cat.12499.jpg
    |   +-- dog.0.jpg
    |   +-- dog.1.jpg
    |   ...
    |   +-- dog.12499.jpg
+-- ...
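With this layout, each image's class can be read from its file-name prefix (<class>.<index>.jpg). The sketch below shows one way to build a labeled file list and hold out a validation split. It is not the repository's dataset.py. The label encoding (cat = 1, dog = 0), the fixed shuffle seed, and the function names load_labeled_paths and split_train_val are all assumptions made for illustration.

```python
import random
from pathlib import Path
from typing import List, Tuple


def load_labeled_paths(train_dir: str) -> List[Tuple[str, int]]:
    """Collect (path, label) pairs; assumed encoding: cat = 1, dog = 0."""
    pairs = []
    for path in sorted(Path(train_dir).glob("*.jpg")):
        prefix = path.name.split(".")[0]  # "cat.0.jpg" -> "cat"
        if prefix == "cat":
            pairs.append((str(path), 1))
        elif prefix == "dog":
            pairs.append((str(path), 0))
    return pairs


def split_train_val(pairs: List[Tuple[str, int]], valset_ratio: float = 0.1,
                    seed: int = 0) -> Tuple[List[Tuple[str, int]], List[Tuple[str, int]]]:
    """Shuffle deterministically, then hold out valset_ratio of the pairs."""
    shuffled = list(pairs)
    random.Random(seed).shuffle(shuffled)
    n_val = int(len(shuffled) * valset_ratio)
    return shuffled[n_val:], shuffled[:n_val]
```

For example, split_train_val(load_labeled_paths("datasets/train"), valset_ratio=0.1) would hold out 2,500 of the 25,000 training images for validation.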

Once everything is set up, run the following command to train with the default arguments:

$ python train.py

Alternatively, you can pass your own arguments:

$ python train.py --epochs 30 --batch_size 32 --valset_ratio .1 --optim sgd --lr_decay_step 10

See train.py for available arguments.
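The flags in the example command could be parsed with argparse along the lines of the sketch below. The flag names come from that command. Everything else is an assumption, so check train.py for the real definitions:

- the default values
- parsing --optim as a free-form string
- reading --lr_decay_step as a step schedule that multiplies the learning rate by gamma every lr_decay_step epochs

build_parser, step_decay_lr, base_lr and gamma are hypothetical names.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical parser mirroring the flags in the example command."""
    parser = argparse.ArgumentParser(description="Train a cat vs dog CNN (sketch)")
    # Placeholder defaults: these are NOT taken from train.py.
    parser.add_argument("--epochs", type=int, default=30)
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--valset_ratio", type=float, default=0.1)
    parser.add_argument("--optim", type=str, default="sgd")
    parser.add_argument("--lr_decay_step", type=int, default=10)
    return parser


def step_decay_lr(base_lr: float, epoch: int, lr_decay_step: int,
                  gamma: float = 0.1) -> float:
    """Assumed schedule: scale the LR by gamma once every lr_decay_step epochs."""
    return base_lr * gamma ** (epoch // lr_decay_step)
```

Suppose base_lr is 0.01 (a hypothetical value) and --lr_decay_step is 10. This schedule would then use 0.01 for epochs 0 to 9, about 0.001 for epochs 10 to 19, and about 0.0001 for epochs 20 to 29.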

Visualizing Learning with TensorBoard

During training, you can monitor its progress by running:

$ tensorboard --logdir runs

Then open localhost:6006 in a browser to view the TensorBoard summaries.

Training and Validation Flow

Figure: The complete training and validation flow, including the CNN model and other training/validation operations such as the optimizer, saver, and accuracy counter.

Model Performance

Figure: Training and validation loss, and validation accuracy, per epoch.

Optimized Network Graph

Figure: The optimized network graph.

Predict Using the Optimized Frozen Graph

See predict.py for details and a demo.

Figure: Default image used for the predict.py demo (images/test.png).

To run the demo:

$ python predict.py

The output should be:

Predicting catness on images/test.png using model from baseline_model/optimized_net_best_acc.pb
Catness: 16.460064
Cat Probability: 1.000000
It's a cat.

If you have your own cat or dog photo, you can test it with:

$ python predict.py --path path/to/your/img.png

PNG, JPG, and BMP images are supported.
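The demo prints a raw score ("Catness") followed by a "Cat Probability". The numbers above are consistent with the probability being the logistic sigmoid of the catness score: sigmoid(16.460064) ≈ 0.99999993, which prints as 1.000000. The sketch below assumes exactly that, plus a 0.5 decision threshold. It reproduces the printed format but is not taken from predict.py. catness_to_probability and describe are hypothetical names.

```python
import math


def catness_to_probability(catness: float) -> float:
    """Logistic sigmoid, split into two branches for numerical stability."""
    if catness >= 0:
        return 1.0 / (1.0 + math.exp(-catness))
    z = math.exp(catness)  # small and safe when catness < 0
    return z / (1.0 + z)


def describe(catness: float, threshold: float = 0.5) -> str:
    """Format a prediction like the demo output (assumed 0.5 threshold)."""
    prob = catness_to_probability(catness)
    label = "cat" if prob >= threshold else "dog"
    return f"Catness: {catness:f}\nCat Probability: {prob:f}\nIt's a {label}."
```

The two-branch form matters for very negative scores. There, computing math.exp(-catness) directly would raise OverflowError.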
