Super-Convergence with JUST PyTorch


When creating Snaked, my snake classification model, I needed to find a way to improve its results. Super-convergence was just that: a way to train a model faster while getting better results! HOWEVER, I found no guides on how to do it with PyTorch's built-in scheduler.

Learn the theory

Before you go through this, you'll probably want to know what super-convergence is and how it works. The general gist is to warm the learning rate up to a very large value early in training and then progressively decrease it, all within a single cycle. This works because large learning rates train faster, but if they are too large, or kept high for too long, the loss diverges. My focus here is on PyTorch though, so I won't explain any further.
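To make the shape of that cycle concrete, here is a small pure-Python sketch. It does a cosine warm-up from a low learning rate to a peak, then a cosine decay to a value far below the starting point. The defaults (pct_start=0.3, div_factor=25, final_div_factor=1e4) are chosen to echo PyTorch's OneCycleLR. However, one_cycle_lr is a hypothetical helper and only a simplified approximation, not PyTorch's exact implementation.

```python
import math

def one_cycle_lr(step, total_steps, max_lr, pct_start=0.3,
                 div_factor=25.0, final_div_factor=1e4):
    """Approximate learning rate at `step` of a one-cycle schedule (simplified sketch)."""
    initial_lr = max_lr / div_factor          # where the warm-up starts
    min_lr = initial_lr / final_div_factor    # where the decay ends
    warmup_steps = pct_start * total_steps

    def cosine(start, end, fraction):
        # Move smoothly from `start` (fraction=0) to `end` (fraction=1)
        return end + (start - end) / 2.0 * (math.cos(math.pi * fraction) + 1.0)

    if step <= warmup_steps:
        # Phase 1: ramp up to the peak learning rate
        return cosine(initial_lr, max_lr, step / warmup_steps)
    # Phase 2: anneal down to a learning rate much lower than the start
    return cosine(max_lr, min_lr, (step - warmup_steps) / (total_steps - warmup_steps))

total = 1000
lrs = [one_cycle_lr(s, total, max_lr=2e-3) for s in range(total + 1)]
print(f"start {lrs[0]:.2e}, peak {max(lrs):.2e}, end {lrs[-1]:.2e}")
```

The peak sits 30% of the way through training, and the final learning rate ends up orders of magnitude below the starting one.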

Here's a list of resources to delve deeper:


import torch
from torchvision import datasets, models, transforms
from torch.utils.data import DataLoader

from torch import nn, optim
from torch_lr_finder import LRFinder

Setting Hyperparameters

Set transforms

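# Named data_transforms so it doesn't shadow the torchvision transforms module
data_transforms = transforms.Compose([
    transforms.RandomResizedCrop(size=256, scale=(0.8, 1)),
    transforms.CenterCrop(size=224),  # ImageNet standards
    transforms.ToTensor(),  # PIL image -> float tensor in [0, 1]; Normalize needs a tensor
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),  # ImageNet standards
])

Load the data, model and basic hyperparameters

# batch_size=64 is an example value; DataLoader defaults to a batch size of 1
train_loader = DataLoader(datasets.CIFAR10(root="train_data", train=True, download=True, transform=data_transforms), batch_size=64, shuffle=True)
test_loader = DataLoader(datasets.CIFAR10(root="test_data", train=False, download=True, transform=data_transforms), batch_size=64)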
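The "ImageNet standards" numbers are per-channel means and standard deviations. Once pixels are scaled to [0, 1] (which transforms.ToTensor() does), Normalize computes (x - mean) / std for each channel. The tiny pure-Python version below applies this to a single pixel to show the arithmetic. Note that normalize_pixel is a hypothetical helper, not a torchvision function.

```python
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

def normalize_pixel(rgb_uint8):
    """Mimic ToTensor + Normalize for one RGB pixel given as 0-255 ints."""
    scaled = [c / 255.0 for c in rgb_uint8]  # ToTensor: scale to [0, 1]
    # Normalize: subtract the channel mean, divide by the channel std
    return [(x - m) / s for x, m, s in zip(scaled, IMAGENET_MEAN, IMAGENET_STD)]

print(normalize_pixel((255, 255, 255)))  # a white pixel
print(normalize_pixel((0, 0, 0)))        # a black pixel
```

Pixel values end up roughly centred on zero, on the same scale the pretrained MobileNet saw during its ImageNet training.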

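# On torchvision >= 0.13, prefer weights=models.MobileNet_V2_Weights.DEFAULT
model = models.mobilenet_v2(pretrained=True)
# The pretrained head predicts 1000 ImageNet classes; CIFAR-10 only has 10
model.classifier[1] = nn.Linear(model.last_channel, 10)

# Set the device in use to GPU (when it's available)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

criterion = nn.CrossEntropyLoss()
# PyTorch's docs recommend moving the model to the device before building its optimizer
optimizer = optim.AdamW(model.parameters())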

Find the perfect learning rate

Note that this step requires a separate library, torch_lr_finder (pip install torch-lr-finder), from [here](

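lr_finder = LRFinder(model, optimizer, criterion, device=device)
lr_finder.range_test(train_loader, end_lr=10, num_iter=1000)
lr_finder.plot()   # loss against learning rate, on a log-scaled x-axis
lr_finder.reset()  # restore the model and optimizer to their initial state before training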

Stopping early, the loss has diverged
Learning rate search finished. See the graph with {finder_name}.plot()
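Under the hood, an LR range test trains batch after batch while growing the learning rate geometrically, recording the loss at each step. It gives up once the loss blows up. The sketch below is a hypothetical, library-free illustration of that idea, not torch-lr-finder's code:

- toy_loss is a made-up stand-in for "train one batch at this learning rate".
- The stopping rule (loss above diverge_th times the best loss so far) is an assumption modelled on the library's diverge_th option.

The real library also smooths the loss before checking it.

```python
import math

def toy_loss(lr):
    """Made-up loss curve: flat for tiny LRs, dips, then explodes for large LRs."""
    return 2.0 / (1.0 + 1e3 * lr) + 50.0 * lr ** 2

def lr_range_test(loss_at, start_lr=1e-7, end_lr=10.0, num_iter=100, diverge_th=5.0):
    """Sweep learning rates geometrically, stopping once the loss diverges."""
    lrs, losses = [], []
    best_loss = math.inf
    for i in range(num_iter):
        # Geometric (exponential) spacing: equal steps on a log-scaled axis
        lr = start_lr * (end_lr / start_lr) ** (i / (num_iter - 1))
        loss = loss_at(lr)
        lrs.append(lr)
        losses.append(loss)
        best_loss = min(best_loss, loss)
        if loss > diverge_th * best_loss:
            print("Stopping early, the loss has diverged")
            break
    return lrs, losses

lrs, losses = lr_range_test(toy_loss)
print(f"swept {len(lrs)} learning rates, last one {lrs[-1]:.3g}")
```

Because the loss explodes well before end_lr, the sweep stops early, just like the run above.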


Create a scheduler

Use the one cycle learning rate scheduler (for super-convergence).

Note that the scheduler's maximum learning rate (max_lr, 2e-3 below) is read off the graph. To choose it, look for the point where the loss falls most steeply (the steepest downward slope).
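If you would rather pick the value programmatically than by eye, the steepest-slope rule can be sketched as below. It assumes you pass in the lists recorded by the finder, which torch-lr-finder keeps in lr_finder.history["lr"] and lr_finder.history["loss"]. Note the following:

- steepest_descent_lr is a hypothetical helper.
- Its skip_start=10 / skip_end=5 trimming is an assumption meant to echo lr_finder.plot()'s defaults.
- The example curve is made up.

Recent versions of torch-lr-finder can also mark a suggestion on the plot (suggest_lr=True) using a similar steepest-gradient idea, though not this exact code.

```python
import math

def steepest_descent_lr(lrs, losses, skip_start=10, skip_end=5):
    """Return the learning rate at which the loss falls fastest.

    Slopes are measured against log10(lr), matching the log-scaled x-axis
    of the finder plot. The noisy first and last points are trimmed first.
    """
    lrs = lrs[skip_start:len(lrs) - skip_end]
    losses = losses[skip_start:len(losses) - skip_end]
    best_lr, best_slope = None, math.inf
    for i in range(1, len(lrs)):
        slope = (losses[i] - losses[i - 1]) / (math.log10(lrs[i]) - math.log10(lrs[i - 1]))
        if slope < best_slope:  # most negative slope = steepest descent
            best_lr, best_slope = lrs[i], slope
    return best_lr

# Made-up finder curve whose loss drops most sharply around lr = 1e-3
example_lrs = [10 ** (-6 + 0.1 * k) for k in range(60)]
example_losses = [2.0 - 1.5 / (1 + math.exp(-4 * (math.log10(lr) + 3))) for lr in example_lrs]
print(f"suggested max_lr ~ {steepest_descent_lr(example_lrs, example_losses):.1e}")
```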

The number of epochs to train for and the number of steps per epoch must also be passed in. A step is one batch, so steps_per_epoch is the number of batches in the training loader, len(train_loader), not the batch size.

scheduler = optim.lr_scheduler.OneCycleLR(optimizer, 2e-3, epochs=50, steps_per_epoch=len(train_loader))
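Because a step is one batch, OneCycleLR plans its whole cycle over epochs * len(train_loader) steps. The arithmetic below is a quick sanity check. num_samples=50,000 is CIFAR-10's training-set size and batch_size=64 is an assumed example value. With DataLoader's default batch size of 1, len(train_loader) would simply be 50,000. batches_per_epoch is a hypothetical helper that mirrors what len() returns for a DataLoader over a sized dataset.

```python
import math

def batches_per_epoch(num_samples, batch_size, drop_last=False):
    """Number of batches a DataLoader yields per epoch (what len(loader) returns)."""
    if drop_last:
        return num_samples // batch_size        # the final partial batch is dropped
    return math.ceil(num_samples / batch_size)  # the final partial batch is kept

per_epoch = batches_per_epoch(50_000, 64)
total_steps = 50 * per_epoch  # epochs * steps_per_epoch: the length of the one cycle
print(per_epoch, total_steps)  # prints: 782 39100
```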

Train model

Train the model for 50 epochs. Print stats after every epoch (loss and accuracy).

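Different schedulers need their step() method called in different places within the code. Calling it in the wrong place causes bugs: the learning rate follows the wrong curve, or OneCycleLR raises an error once it has been stepped more times than it expects. With the one-cycle policy, make sure scheduler.step() is called straight after each batch, right after optimizer.step().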
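To see where the call belongs, here is a minimal, self-contained sketch. It uses a throwaway nn.Linear model and random data, both stand-ins rather than the MobileNet setup from this post, with names prefixed toy_ so they don't clash with the real ones. It calls optimizer.step() and then scheduler.step() once per batch, so the scheduler is stepped exactly epochs * steps_per_epoch times. One extra call would push OneCycleLR past its total number of steps, and it would raise a ValueError.

```python
import torch
from torch import nn, optim

# Tiny stand-in model and fake data; only the order of the calls matters here
toy_model = nn.Linear(4, 2)
toy_optimizer = optim.AdamW(toy_model.parameters())
epochs, steps_per_epoch = 3, 5
toy_scheduler = optim.lr_scheduler.OneCycleLR(
    toy_optimizer, max_lr=2e-3, epochs=epochs, steps_per_epoch=steps_per_epoch)

toy_lrs = []
for epoch in range(epochs):
    for _ in range(steps_per_epoch):  # one iteration per (fake) batch
        inputs = torch.randn(8, 4)
        labels = torch.randint(0, 2, (8,))
        toy_optimizer.zero_grad()
        loss = nn.functional.cross_entropy(toy_model(inputs), labels)
        loss.backward()
        toy_optimizer.step()  # update the weights first...
        toy_scheduler.step()  # ...then advance the learning rate by one batch
        toy_lrs.append(toy_scheduler.get_last_lr()[0])

print(f"{len(toy_lrs)} scheduler steps, peak LR {max(toy_lrs):.2e}, final LR {toy_lrs[-1]:.2e}")
```

Stepping the scheduler once per epoch instead would only cover the first few steps of the cycle, so the learning rate would never reach its peak or anneal down.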

best_acc = 0.0
epoch_no_change = 0
stop_training = False
# Train on the training set, validate on the held-out test set
data_loaders = {"train": train_loader, "validation": test_loader}

for epoch in range(50):
    print(f"Epoch {epoch}/49")

    for phase in ["train", "validation"]:
        running_loss = 0.0
        running_corrects = 0

        # PyTorch model's state must be changed,
        # as layers like dropout work differently depending on state
        if phase == "train":
            model.train()
        else:
            model.eval()

        # Loop through the dataset for the current phase
        for inputs, labels in data_loaders[phase]:
            # Transfer data to the GPU (if one is in use)
            inputs, labels = inputs.to(device), labels.to(device)

            # Reset the gradient (so the gradient doesn't accumulate)
            optimizer.zero_grad()

            with torch.set_grad_enabled(phase == "train"):
                # Predict the label which the model gives the max probability (of being true)
                outputs = model(inputs)
                _, preds = torch.max(outputs, 1)
                loss = criterion(outputs, labels)

                if phase == "train":
                    # Backprop
                    loss.backward()
                    optimizer.step()
                    # Super-convergence changes the learning rate after every batch
                    scheduler.step()

            running_loss += loss.item() * inputs.size(0)
            running_corrects += torch.sum(preds == labels).item()

        # Calculate and output metrics
        dataset_size = len(data_loaders[phase].dataset)
        epoch_loss = running_loss / dataset_size
        epoch_acc = running_corrects / dataset_size
        print("\nPhase: {}, Loss: {:.4f}, Acc: {:.4f}".format(phase, epoch_loss, epoch_acc))

        # Stop the model from training further if it hasn't improved for 5 consecutive epochs
        if phase == "validation":
            if epoch_acc > best_acc:
                best_acc = epoch_acc
                epoch_no_change = 0
            else:
                epoch_no_change += 1

            if epoch_no_change >= 5:
                stop_training = True

    if stop_training:
        print(f"No improvement for {epoch_no_change} epochs, stopping early")
        break
Thanks for READING!

I hope this is easy enough to understand quickly. When I first implemented super-convergence, it took me a long time to figure out how to use the scheduler, as I couldn't find any code that used it. If you liked this blog post, consider checking out other ways to improve your model. If you'd like to see how super-convergence is used in a real project, look no further than my snake classification project.
