Improving your computer vision model


Why?

So you've cleaned your data and written some basic code to train a model, but now you don't know where to go next. Don't worry, I've got your back. I'm going to explain, in as much detail as I can, the tricks I've learnt that can help improve most models.

Data Augmentation

Data augmentations are modifications you can make to your input data. They can make your model more robust and help it generalise better. There is a wide variety available, but I'll just describe some of the ones I've tried:

  • Randomly crop 80%-100% of each image
  • Adjust the aspect ratio
  • Apply random colour jitter
  • Resize images to a height of 256 pixels
  • Centre crop the image to 224x224 pixels
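To make the crop and aspect-ratio steps concrete, here is a minimal pure-Python sketch. It shows how a random crop box can be sampled, how a single brightness factor can be applied to a whole image, and the arithmetic behind resizing and centre cropping. All function names are hypothetical. The aspect-ratio range of 3/4 to 4/3 is an assumption borrowed from torchvision's default.

In practice you would probably use torchvision instead. There, `transforms.RandomResizedCrop(224, scale=(0.8, 1.0))` covers the first two bullets, and `transforms.ColorJitter`, `transforms.Resize(256)` and `transforms.CenterCrop(224)` cover the rest. Note that an integer passed to `Resize` sets the shorter side, not specifically the height. The resize and centre crop are deterministic, so they are usually the preprocessing applied to validation images rather than a random augmentation.

```python
import math
import random

def sample_crop_box(width, height, scale=(0.8, 1.0), ratio=(3 / 4, 4 / 3), rng=random):
    """Pick a random crop covering `scale` of the image area with a jittered aspect ratio.

    Returns (left, top, crop_width, crop_height). The ratio range is an assumption.
    """
    area = width * height
    log_lo, log_hi = math.log(ratio[0]), math.log(ratio[1])
    for _ in range(10):  # retry a few times if the sampled box doesn't fit
        target_area = area * rng.uniform(*scale)
        aspect = math.exp(rng.uniform(log_lo, log_hi))  # sample the ratio in log space
        w = round(math.sqrt(target_area * aspect))
        h = round(math.sqrt(target_area / aspect))
        if 0 < w <= width and 0 < h <= height:
            left = rng.randint(0, width - w)
            top = rng.randint(0, height - h)
            return left, top, w, h
    return 0, 0, width, height  # fallback: keep the whole image

def jitter_brightness(pixels, strength=0.2, rng=random):
    """Scale every (r, g, b) pixel by one random factor in [1 - strength, 1 + strength]."""
    factor = rng.uniform(1 - strength, 1 + strength)  # one factor for the whole image
    return [tuple(min(255, max(0, round(c * factor))) for c in px) for px in pixels]

def resize_to_height(width, height, target_height=256):
    """New (width, height) after resizing to `target_height`, keeping the aspect ratio."""
    return round(width * target_height / height), target_height

def center_crop_box(width, height, size=224):
    """(left, top, w, h) of a size x size crop taken from the middle of the image."""
    return (width - size) // 2, (height - size) // 2, size, size

print(sample_crop_box(640, 480, rng=random.Random(0)))
print(resize_to_height(640, 480))   # (341, 256)
print(center_crop_box(341, 256))    # (58, 16, 224, 224)
```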

Super-Convergence

The point of super-convergence is to speed up training whilst also improving performance (a win-win situation). It builds on the idea that higher learning rates train models faster and can act as a regularizer. It uses a one-cycle learning rate schedule: the learning rate is raised to a large maximum during the first part of training and then decreased again towards the end. You can also try the AdamW optimizer (Adam with decoupled weight decay), which is often reported to give better results. I'm writing a whole article on how to use super-convergence in pure PyTorch, so if you're interested, make sure to check that out! For the nitty-gritty details, take a look at the original paper ("Super-Convergence: Very Fast Training of Neural Networks Using Large Learning Rates" by Leslie N. Smith and Nicholay Topin).
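Here is a minimal pure-Python sketch of a one-cycle schedule, assuming cosine-shaped warm-up and annealing phases. The parameter names and defaults (`pct_start=0.3`, `div_factor=25`, `final_div_factor=1e4`) mirror PyTorch's `torch.optim.lr_scheduler.OneCycleLR`. The function itself is a hypothetical illustration, not that class's exact implementation. In a real PyTorch loop you would pair `OneCycleLR` with an optimizer such as `torch.optim.AdamW` and call `scheduler.step()` after every batch.

```python
import math

def _cos_anneal(start, end, pct):
    """Cosine interpolation from `start` (pct=0) to `end` (pct=1)."""
    return end + (start - end) / 2.0 * (math.cos(math.pi * pct) + 1.0)

def one_cycle_lr(step, total_steps, max_lr, pct_start=0.3,
                 div_factor=25.0, final_div_factor=1e4):
    """Learning rate at `step`: warm up to max_lr, then anneal far below the start."""
    initial_lr = max_lr / div_factor
    min_lr = initial_lr / final_div_factor
    warmup_steps = max(1, int(pct_start * total_steps))
    if step < warmup_steps:
        # Phase 1: rise from initial_lr to max_lr
        return _cos_anneal(initial_lr, max_lr, step / warmup_steps)
    # Phase 2: fall from max_lr to min_lr by the final step
    decay_steps = max(1, total_steps - 1 - warmup_steps)
    pct = min(1.0, (step - warmup_steps) / decay_steps)
    return _cos_anneal(max_lr, min_lr, pct)

schedule = [one_cycle_lr(s, total_steps=100, max_lr=0.1) for s in range(100)]
print(f"start={schedule[0]:.4g}, peak={max(schedule):.4g}, end={schedule[-1]:.4g}")
```

The high peak is where the regularising effect of large learning rates is meant to come from. The long annealing tail lets the loss settle at the end.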

Learning rate finder/tuning hyperparameters

This one is usually done in combination with super-convergence, but it can also be used on its own. The idea is to train briefly while steadily increasing the learning rate, then plot learning rate against loss. This way you can find the maximum learning rate you can safely use (before training diverges and the loss starts increasing).

Note that during training the learning rate still needs to be decreased, or else the loss will increase again (super-convergence does this for you).
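Here is a minimal sketch of the learning rate range test. It assumes a hypothetical `train_step(lr)` callback that runs one mini-batch update at the given learning rate and returns the loss. The exponential sweep, the "stop once the loss exceeds 4x the best" rule and the "take the loss minimum divided by 10" suggestion are common heuristics, not a fixed standard. To keep the example self-contained, the "model" is a single weight minimising w², for which plain gradient descent diverges once the learning rate exceeds 1.

```python
import math

def lr_range_test(train_step, min_lr=1e-6, max_lr=10.0, num_steps=100, diverge_factor=4.0):
    """Sweep the learning rate exponentially from min_lr to max_lr, one step each.

    Stops early once the loss exceeds diverge_factor times the best loss so far.
    """
    gamma = (max_lr / min_lr) ** (1.0 / (num_steps - 1))
    lrs, losses = [], []
    best = math.inf
    for i in range(num_steps):
        lr = min_lr * gamma ** i
        loss = train_step(lr)
        lrs.append(lr)
        losses.append(loss)
        best = min(best, loss)
        if loss > diverge_factor * best:
            break  # the loss is blowing up, so larger learning rates are unsafe
    return lrs, losses

def suggest_lr(lrs, losses):
    """Rule of thumb (an assumption): one tenth of the LR where the loss was lowest."""
    best_idx = min(range(len(losses)), key=losses.__getitem__)
    return lrs[best_idx] / 10

# Toy "model": a single weight w with loss w**2 and gradient 2*w.
state = {"w": 5.0}

def toy_train_step(lr):
    state["w"] -= lr * 2 * state["w"]  # one gradient-descent update
    return state["w"] ** 2

lrs, losses = lr_range_test(toy_train_step)
print(f"stopped at lr={lrs[-1]:.3g}, suggested lr={suggest_lr(lrs, losses):.3g}")
```

With a real network you would plot `losses` against `lrs` on a log-scaled x-axis and read the safe range off the graph.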

Test time augmentation

Test time augmentation (TTA) involves averaging a model's predictions over several augmented versions of each test image. This can improve accuracy. However, I decided against it because of the unusually high VRAM usage of the PyTorch library I found (it's easier with Fast.AI, though).
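Here is a minimal sketch of TTA in pure Python. It assumes a hypothetical `predict(image)` that returns a list of class probabilities, with images stored as nested lists of pixel values. Each extra augmentation costs one more forward pass, which is where the extra memory and compute come from. In PyTorch the same idea would typically be a softmax over the model's outputs for the original and flipped batch (for example via `torch.flip`), averaged together.

```python
def hflip(image):
    """Mirror an image (a list of pixel rows) left to right."""
    return [row[::-1] for row in image]

def tta_predict(predict, image, augmentations):
    """Average class probabilities over the original image and its augmented copies."""
    views = [image] + [augment(image) for augment in augmentations]
    all_probs = [predict(view) for view in views]  # one forward pass per view
    num_classes = len(all_probs[0])
    return [sum(probs[c] for probs in all_probs) / len(all_probs) for c in range(num_classes)]

# Hypothetical toy classifier: class 0 = "bright on the left", class 1 = "bright on the right".
def toy_predict(image):
    left = sum(row[0] for row in image)
    right = sum(row[-1] for row in image)
    return [left / (left + right), right / (left + right)]

image = [[3, 1], [3, 1]]
print(toy_predict(image))                        # [0.75, 0.25]
print(tta_predict(toy_predict, image, [hflip]))  # [0.5, 0.5]
```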

Balance dataset

This one is REALLY important. I didn't notice it at first, but someone who looked at my original code saw that I hadn't balanced the number of images per class. This meant that classes with fewer images were significantly less likely to be predicted. The reason is that classes with more images contribute much more to the loss (at least for standard loss functions such as cross-entropy).
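To see the imbalance in numbers, here is a minimal pure-Python sketch that computes inverse-frequency class weights. It assumes the common `n / (k * count)` formula; other weightings exist. With these weights each class contributes the same total weight to the loss. In PyTorch, per-class weights like these can be passed as the `weight` argument of `torch.nn.CrossEntropyLoss`. Per-sample weights can drive oversampling through `torch.utils.data.WeightedRandomSampler`.

```python
from collections import Counter

def class_weights(labels):
    """Inverse-frequency weights: weight_c = n / (k * count_c)."""
    counts = Counter(labels)
    n, k = len(labels), len(counts)
    return {c: n / (k * count) for c, count in counts.items()}

def sample_weights(labels):
    """One weight per sample, e.g. for weighted oversampling."""
    weights = class_weights(labels)
    return [weights[y] for y in labels]

labels = ["cat"] * 90 + ["dog"] * 10  # hypothetical 90/10 imbalance
weights = class_weights(labels)
print(weights)  # cat ≈ 0.556, dog = 5.0
per_sample = sample_weights(labels)
```

Without weights, the 90 cat images would carry 90% of the total loss. With them, each class carries half.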

The best way to deal with this is to get more data; however, in many problems this just isn't possible.

The simplest alternative is oversampling the minority classes, but this can lead to overfitting. Another option is undersampling the majority classes, but this may discard important samples! Instead, try out a newer loss function such as LDAM (label-distribution-aware margin) loss.
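Here is a minimal pure-Python sketch of the LDAM idea for a single example. Each class j gets a margin proportional to n_j^(-1/4), where n_j is its number of training images, so rarer classes get larger margins. That margin is subtracted from the true class's logit before a scaled softmax cross-entropy. The rescaling so the largest margin equals `max_m`, and the defaults `max_m=0.5` and `s=30`, are my assumptions about typical settings, not values from this article. Treat them as hyperparameters to tune.

```python
import math

def ldam_margins(class_counts, max_m=0.5):
    """Per-class margins proportional to n_j ** -0.25, rescaled so the largest is max_m."""
    raw = [1.0 / count ** 0.25 for count in class_counts]
    scale = max_m / max(raw)
    return [m * scale for m in raw]

def ldam_loss(logits, target, margins, s=30.0):
    """Cross-entropy on scaled logits after subtracting the target class's margin."""
    adjusted = [z - margins[j] if j == target else z for j, z in enumerate(logits)]
    scaled = [s * z for z in adjusted]
    peak = max(scaled)  # subtract the max for a numerically stable log-sum-exp
    log_sum_exp = peak + math.log(sum(math.exp(v - peak) for v in scaled))
    return log_sum_exp - scaled[target]

margins = ldam_margins([1000, 10])  # hypothetical: a common class and a rare one
print(margins)  # the rare class (index 1) gets the larger margin
```

Because the rare class needs a bigger logit gap to reach the same loss, the model is pushed to separate it more confidently.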

Confusion Matrices

I'm not an expert here, but in essence a confusion matrix shows, for each true class, how often the model predicted each class, so you can find out which classes aren't being classified well. From there you can analyse those classes, try to work out why the model struggles with them, and then improve it (e.g. maybe more images of one class are needed).
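Here is a minimal pure-Python sketch, assuming integer class labels. It uses the same convention as scikit-learn's `sklearn.metrics.confusion_matrix`: rows are true classes and columns are predicted classes. Dividing each diagonal entry by its row total gives per-class recall, which points straight at the weakest classes. The labels here are made up.

```python
def confusion_matrix(y_true, y_pred, num_classes):
    """cm[t][p] counts examples of true class t that were predicted as class p."""
    cm = [[0] * num_classes for _ in range(num_classes)]
    for t, p in zip(y_true, y_pred):
        cm[t][p] += 1
    return cm

def per_class_recall(cm):
    """Fraction of each true class that was predicted correctly."""
    return [row[i] / sum(row) if sum(row) else 0.0 for i, row in enumerate(cm)]

y_true = [0, 0, 0, 1, 1, 2, 2, 2, 2]  # hypothetical labels
y_pred = [0, 0, 1, 1, 1, 2, 0, 2, 0]
cm = confusion_matrix(y_true, y_pred, num_classes=3)
recall = per_class_recall(cm)
print(cm)      # [[2, 1, 0], [0, 2, 0], [2, 0, 2]]
print(recall)  # class 2 is the weakest; it is often confused with class 0
```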

Progressive resizing

This one's simple and effective. You start by training your model on small, low-resolution images and then progressively increase the resolution. The idea is that this improves robustness by forcing the model to learn simple patterns before complex ones.
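Here is a minimal sketch of a progressive-resizing schedule. It assumes you rebuild your training transforms (for example, the crop size) at the start of each stage. The stage boundaries and sizes below are hypothetical. This works for CNNs that end in global (adaptive) pooling, such as torchvision's ResNets, because they accept different input resolutions.

```python
import bisect

# Hypothetical schedule: (first epoch of the stage, image side length in pixels).
SCHEDULE = [(0, 128), (5, 192), (10, 224)]

def image_size_for_epoch(epoch, schedule=SCHEDULE):
    """Return the image size for the stage that `epoch` falls into."""
    starts = [start for start, _ in schedule]
    idx = bisect.bisect_right(starts, epoch) - 1
    if idx < 0:
        raise ValueError(f"epoch {epoch} is before the first stage")
    return schedule[idx][1]

for epoch in range(12):
    size = image_size_for_epoch(epoch)
    # e.g. rebuild the training transforms / DataLoader here when `size` changes
    print(epoch, size)
```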

Consider metrics other than accuracy

Accuracy may not be the best metric for your problem, especially when the classes are imbalanced. Metrics like the F1 score can be equally if not more useful!
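Here is a small pure-Python illustration of why accuracy can mislead. On a made-up set with 95 negatives and 5 positives, a model that always predicts the majority class scores 95% accuracy but an F1 of 0 for the minority class. scikit-learn's `sklearn.metrics.f1_score` computes the same quantity, with an `average` parameter for multi-class problems.

```python
def accuracy(y_true, y_pred):
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)

def precision_recall_f1(y_true, y_pred, positive=1):
    """Precision, recall and F1 for the `positive` class (0.0 when undefined)."""
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == positive and p == positive)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t != positive and p == positive)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == positive and p != positive)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1

y_true = [0] * 95 + [1] * 5        # hypothetical: only 5% positives
always_zero = [0] * 100            # lazy model: always predicts the majority class
better = [0] * 93 + [1] * 6 + [0]  # 2 false positives, 4 of 5 positives found

print(accuracy(y_true, always_zero), precision_recall_f1(y_true, always_zero))
print(accuracy(y_true, better), precision_recall_f1(y_true, better))
```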

THANKS FOR READING!

Now that you've heard me ramble, I'd like to thank you for taking the time to read through my blog (or skipping to the end). If this has helped you out, consider checking out my article on problems I encountered whilst building my first project!

