Averaging Weights Leads to Wider Optima and Better Generalization
16 Jul 2020

Introduction
- The paper proposes the Stochastic Weight Averaging (SWA) procedure for improving the generalization performance of models trained with SGD (with a cyclical or constant learning rate).
- Specifically, the model is checkpointed at several points along the training trajectory, and these checkpoints are averaged (in parameter space) to obtain a single model.
Idea
- The “stochastic” in the name refers to the idea that, with a cyclical or constant learning rate, the SGD proposals are approximately samples from the network’s loss surface and are hence stochastic.
- SWA uses a learning rate schedule that allows exploration in the weight space.
- SGD with a cyclical or constant learning rate explores points (model instances) at the periphery of the set of high-performing networks.
- With different initializations, SGD will find different points of low training loss on this periphery, but it will not move inside it.
- Averaging these points provides a mechanism for moving inside the periphery.
- The train and test error surfaces, while similar, are not perfectly aligned. Hence, averaging several models along the optimization trajectory can lead to a more robust model (the toy sketch below illustrates this).
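To make the last point concrete, here is a toy sketch (not from the paper, purely illustrative): a quadratic “train” loss with a slightly shifted “test” loss, and proposals scattered on a ring around the train optimum. The optima, radius, and sample count are arbitrary choices; averaging the proposals moves inside the ring and also reduces the test loss.

```python
# Toy illustration (not from the paper): "SGD proposals" sit on a ring around the
# optimum of a quadratic train loss; the test loss is the same bowl, slightly
# shifted. Averaging the proposals moves inside the ring and also helps on test.
import numpy as np

rng = np.random.default_rng(0)

train_opt = np.array([0.0, 0.0])   # optimum of the "train" surface
test_opt = np.array([0.4, 0.4])    # the "test" surface is a shifted copy

def train_loss(w):
    return float(np.sum((w - train_opt) ** 2))

def test_loss(w):
    return float(np.sum((w - test_opt) ** 2))

# Proposals scattered on the periphery of the low-train-loss region (unit ring).
angles = rng.uniform(0.0, 2.0 * np.pi, size=10)
proposals = np.stack([np.cos(angles), np.sin(angles)], axis=1)

average = proposals.mean(axis=0)   # parameter-space average, the "SWA" point

print("mean train loss of proposals:", np.mean([train_loss(w) for w in proposals]))
print("train loss of their average :", train_loss(average))
print("mean test loss of proposals :", np.mean([test_loss(w) for w in proposals]))
print("test loss of their average  :", test_loss(average))
```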
Algorithm
- Given a model w and some training budget B, train the model in the conventional way for approximately 75% of the budget.
- Starting from that point, continue training for the remaining budget with a constant or cyclical learning rate.
- For a constant learning rate, checkpoint the model at the end of each epoch. For a cyclical learning rate, checkpoint the model at the lowest learning rate in each cycle.
- Average all the checkpoints (in parameter space) to obtain the SWA model (see the training-loop sketch after this list).
- If the model has Batch Normalization layers, run one additional pass over the training data to compute the running mean and standard deviation of the activations for the SWA model; these statistics are never collected for the averaged weights during training, so they have to be recomputed.
- The computational and space complexity of maintaining the SWA model is low: only one extra copy of the weights is stored, and it is updated as a running average, w_SWA ← (n · w_SWA + w) / (n + 1), at each checkpoint.
- The paper highlights the ensembling-like effect of SWA by showing that if the checkpoints w_i are generated by training with Fast Geometric Ensembling (FGE), the difference between averaging the weights and averaging the predictions is of the order O(Δ), where Δ = max_i ||w_i − w_SWA|| (the first-order argument is sketched after this list).
- Note that, unlike a prediction ensemble, SWA does not incur the overhead of extra forward passes at inference time: only the single averaged model is evaluated.
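The steps above can be summarized in a short PyTorch sketch. The model and data loader are assumed given; every hyperparameter below (epochs, learning rates, cycle length, the 75% split point) is an illustrative placeholder rather than the paper’s exact setting, and the learning-rate cycle is measured in epochs for brevity.

```python
# Minimal PyTorch sketch of the SWA procedure described above.
import copy

import torch
import torch.nn.functional as F


def cyclical_lr(step_in_cycle, cycle_len, alpha1, alpha2):
    """Linearly anneal the learning rate from alpha1 down to alpha2 within a cycle."""
    t = (step_in_cycle + 1) / cycle_len
    return (1.0 - t) * alpha1 + t * alpha2


def train_one_epoch(model, loader, optimizer, lr):
    model.train()
    for group in optimizer.param_groups:
        group["lr"] = lr
    for x, y in loader:
        optimizer.zero_grad()
        F.cross_entropy(model(x), y).backward()
        optimizer.step()


def train_with_swa(model, loader, epochs=100, lr_init=0.1,
                   alpha1=0.05, alpha2=0.0005, cycle_len=4):
    optimizer = torch.optim.SGD(model.parameters(), lr=lr_init, momentum=0.9)

    # Phase 1: conventional SGD training for roughly 75% of the budget.
    swa_start = int(0.75 * epochs)
    for _ in range(swa_start):
        train_one_epoch(model, loader, optimizer, lr_init)

    # Phase 2: cyclical learning rate; at the end of every cycle (i.e. at the
    # lowest learning rate) fold the current weights into a running average:
    #   w_SWA <- (n * w_SWA + w) / (n + 1)
    swa_model = copy.deepcopy(model)
    n_averaged = 0
    for epoch in range(swa_start, epochs):
        step_in_cycle = (epoch - swa_start) % cycle_len
        lr = cyclical_lr(step_in_cycle, cycle_len, alpha1, alpha2)
        train_one_epoch(model, loader, optimizer, lr)
        if step_in_cycle == cycle_len - 1:  # end of a cycle -> checkpoint
            with torch.no_grad():
                for swa_p, p in zip(swa_model.parameters(), model.parameters()):
                    if n_averaged == 0:
                        swa_p.copy_(p)
                    else:
                        swa_p.mul_(n_averaged / (n_averaged + 1)).add_(p / (n_averaged + 1))
            n_averaged += 1

    # Phase 3: one extra pass over the data so that the BatchNorm layers hold
    # statistics computed with the averaged weights. Reset the buffers first so
    # the pass accumulates fresh (cumulative) statistics.
    for module in swa_model.modules():
        if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
            module.reset_running_stats()
            module.momentum = None  # use a cumulative moving average
    swa_model.train()
    with torch.no_grad():
        for x, _ in loader:
            swa_model(x)
    return swa_model
```

With a constant learning rate instead of a cyclical one, the same loop applies with `cyclical_lr` replaced by a fixed value and a checkpoint taken at the end of every epoch.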
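For reference, the first-order argument behind the ensembling-like effect can be reconstructed from a Taylor expansion around w_SWA (a sketch, assuming the predictions are locally well approximated by their linearization; here \(\bar{f}\) denotes the FGE ensemble prediction over the n checkpoints):

$$
f(w_i; x) = f(w_{\mathrm{SWA}}; x) + \big\langle \nabla f(w_{\mathrm{SWA}}; x),\, w_i - w_{\mathrm{SWA}} \big\rangle + O\big(\lVert w_i - w_{\mathrm{SWA}} \rVert^2\big),
$$

$$
\bar{f}(x) = \frac{1}{n}\sum_{i=1}^{n} f(w_i; x)
= f(w_{\mathrm{SWA}}; x) + \Big\langle \nabla f(w_{\mathrm{SWA}}; x),\, \frac{1}{n}\sum_{i=1}^{n} \big(w_i - w_{\mathrm{SWA}}\big) \Big\rangle + O(\Delta^2).
$$

Since every \(\lVert w_i - w_{\mathrm{SWA}} \rVert \le \Delta\), the ensemble prediction and the prediction of the averaged weights differ by at most O(Δ); when w_SWA is exactly the mean of the w_i, the first-order term cancels entirely.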
Experiments
- Datasets: CIFAR-10, CIFAR-100, ImageNet.
- Models: VGG-16, Wide ResNet, 164-layer preactivation ResNet, Shake-Shake, PyramidNet.
- Baselines: conventional SGD training, an exponentially decaying running average of the SGD iterates, and FGE.
- In all the CIFAR experiments, SWA consistently outperforms SGD within one training budget and keeps improving as the training budget is increased.
- SWA also achieves performance comparable to FGE, despite FGE being an ensemble method.
- On ImageNet, SWA is applied on top of pre-trained models, and it improves performance in all cases (a sketch of this use case with a more recent SWA implementation follows this list).
- An ablation experiment on CIFAR-100 shows that it is also possible to train a network with SWA using a fixed learning rate; in that setup, using SWA improves performance by 16%.
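The paper predates it, but recent PyTorch versions ship an implementation of this procedure in torch.optim.swa_utils. Below is a minimal sketch of the pre-trained-model use case with a constant SWA learning rate; the model, data loader, number of epochs, and learning rates are placeholders, not the paper’s settings.

```python
# Sketch of applying SWA on top of an already-trained model with a constant
# SWA learning rate, using PyTorch's torch.optim.swa_utils (PyTorch >= 1.6).
# `model`, `loader`, and all hyperparameters below are illustrative placeholders.
import torch
import torch.nn.functional as F
from torch.optim.swa_utils import SWALR, AveragedModel, update_bn


def finetune_with_swa(model, loader, swa_epochs=10, lr=0.01, swa_lr=0.001):
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    swa_model = AveragedModel(model)                  # running average of the weights
    swa_scheduler = SWALR(optimizer, swa_lr=swa_lr)   # anneals to, then holds, swa_lr

    for _ in range(swa_epochs):
        model.train()
        for x, y in loader:
            optimizer.zero_grad()
            F.cross_entropy(model(x), y).backward()
            optimizer.step()
        swa_scheduler.step()                # advance the SWA learning-rate schedule
        swa_model.update_parameters(model)  # fold this epoch's weights into the average

    # Extra pass over the data to recompute BatchNorm statistics for the average.
    update_bn(loader, swa_model)
    return swa_model
```

Here AveragedModel keeps the running weight average, SWALR holds the learning rate constant during the averaging phase, and update_bn performs the extra Batch Normalization pass described in the Algorithm section.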