PyTorch 实现了 SWA(Stochastic Weight Averaging,随机加权平均),相比于传统的 SGD,使用 SWA 能够明显改善一些深度神经网络模型的测试精度(Test Accuracy)。而且,SWA 使用起来非常简单,能够加速模型训练,并提高模型的泛化能力。
SWA 基本原理
SWA 依赖两个重要的因素:
第一个是,SWA 使用一个不断修改的 LR 调节器(Learning Rate Schedule),使得 SGD 能够在最优值附近进行调整,并评估最优解附近的值对应的模型的精度,而不是只选取最优解对应的模型。因为,最优解对应的模型不一定是最优的,而且泛化能力可能也不一定最好。比如,在 75% 的训练时间里,可以使用一个标准的衰减学习率(Decaying Learning Rate)策略,然后在剩余 25% 的训练时间里将学习率设置为一个比较高的固定值。如下图所示:
第二个是,SWA 计算的是 SGD 遍历过的神经网络权重的平均值。例如,上面提到模型训练的后 25% 时间,我们可以在这 25% 时间里的每一轮训练(every epoch)后,计算一个权重的 running 平均值,在训练结束后再设置网络模型的权重为 SWA 权重平均值。
SWA 论文提供了对其算法的过程的描述,如下图所示:
上面算法描述,给出了下面几个重要的内容:
- 学习率在模型训练过程中是不断调整的,同时更新 SGD 梯度值
- 基于模型训练的迭代次数,设置在训练过程中不同的阶段(位置)更新两个不同的权重平均值
- 在模型训练结束后,为 SWA 权重计算 BatchNorm 统计量
SWA 使用要点
在 PyTorch 中,使用 SWA 训练模型的基本要点,描述如下:
- 创建 SWA 模型
创建 SWA 模型,直接使用 PyTorch 提供的 AveragedModel,传入上面创建的 model,其中 model 可以是任意继承了 torch.nn.Module 的模型实现:
swa_model = AveragedModel(model)
这样 swa_model 会在模型训练过程中持续跟踪 model 的参数的平均值。如果希望更新这个 running 平均值,需要在 optimizer.step() 之后调用 update_parameters() 函数:
swa_model.update_parameters(model)
- 使用 SWALR
使用 SWA 时,通常配合使用 SWALR 这个 Learning Rate Scheduler 来退火(anneals)至一个固定的常量值,然后一直保持不变,使用示例:
swa_scheduler = torch.optim.swa_utils.SWALR( optimizer, anneal_strategy="linear", anneal_epochs=5, swa_lr=0.05 ) # 或者 swa_scheduler = torch.optim.swa_utils.SWALR( optimizer, anneal_strategy="cos", anneal_epochs=5, swa_lr=0.05 )
- 计算模型在 DataLoader 上的 BN 统计量
在模型训练结束之后,需要进行计算 BN(Batch Normalization)统计量并更新模型:
torch.optim.swa_utils.update_bn(train_dataloader, swa_model)
SWA 编程实践
下面,我们基于上面提到的使用 SWA 的方法训练模型,通过实际编程来加强理解:
1 准备数据集并定义模型
我们使用 MNIST 数据集,神经网络模型使用 LeNet-5,代码处理逻辑如下所示:
import torch from torch import nn from torch.utils.data import DataLoader from torchvision import datasets from torchvision.transforms import ToTensor # Download training/test data from open datasets. train_dataset = datasets.MNIST(root="data", train=True, download=True, transform=ToTensor(),) test_dataset = datasets.MNIST(root="data", train=False, download=False, transform=ToTensor(),) print(f"train_dataset_size = {len(train_dataset)}, test_dataset_size = {len(test_dataset)}") batch_size = 64 # Create data loaders. train_dataloader = DataLoader(train_dataset, batch_size=batch_size) test_dataloader = DataLoader(test_dataset, batch_size=batch_size) for X, y in test_dataloader: print(f"Shape of X [N, C, H, W]: {X.shape}") print(f"Shape of y: {y.shape} {y.dtype}") break # Get cpu, gpu or mps device for training. device = ( "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu" ) print(f"Using {device} device") # Define model class LeNet5Model(nn.Module): def __init__(self): super().__init__() self._conv = nn.Sequential( nn.Conv2d(1, 6, 5, 1), nn.MaxPool2d(2), nn.Conv2d(6, 16, 5, 1), nn.MaxPool2d(2) ) self._fc = nn.Sequential( nn.Linear(4*4*16, 120), nn.Linear(120, 84), nn.Linear(84, 10) ) def forward(self, x): x = self._conv(x) x = x.view(-1, 4 * 4 * 16) x = self._fc(x) return x
代码比较容易,不再赘述。
2 使用 SWA 训练模型
首先,我们创建模型,并指定要使用的 Loss Function 和 Optimizer,实现代码如下所示:
# Create model model = LeNet5Model().to(device) print(model) # Define loss function and optimizer loss_fn = nn.CrossEntropyLoss() optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
使用 SWA,直接在以往的编程基础上,复用上面创建的一些对象即可。
下面,实现基于 SWA 训练模型的核心逻辑代码,如下所示:
from torch.optim.swa_utils import AveragedModel, SWALR from torch.optim.lr_scheduler import CosineAnnealingLR def swa_train(epoch, train_loader, model, loss_fn, optimizer, swa_start): # scheduler = CosineAnnealingLR(optimizer, T_max=10) scheduler = SWALR(optimizer, anneal_strategy="linear", anneal_epochs=5, swa_lr=0.05) swa_scheduler = SWALR(optimizer, swa_lr=0.05) swa_model = AveragedModel(model) size = len(train_loader.dataset) for batch, (X, y) in enumerate(train_loader): optimizer.zero_grad() loss = loss_fn(model(X), y) loss.backward() optimizer.step() if batch % 100 == 0: loss, current = loss.item(), (batch + 1) * len(X) print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]") if epoch > swa_start: swa_model.update_parameters(model) swa_scheduler.step() else: scheduler.step() return swa_model
在模型训练过程中,使用 swa_start 控制启动 SWA 更新模型参数,并使用 swa_scheduler 来调节学习率,否则就使用默认的学习率调节器控制训练过程。
现在,我们就可以调用上面提供的 SWA 模型训练函数 swa_train(),训练我们前面定义的 LeNet 模型,代码如下:
epochs = 10 swa_start = 7 for epoch in range(epochs): print(f"Epoch {epoch + 1}\n-------------------------------") swa_model = swa_train(epoch, train_dataloader, model, loss_fn, optimizer, swa_start) # Update bn statistics for the swa_model at the end torch.optim.swa_utils.update_bn(train_dataloader, swa_model) test(test_dataloader, swa_model, loss_fn)
在调用 swa_train() 函数训练的迭代过程中,并没有对模型进行预测调用,所以 Batch Norm 层也就没有计算过神经网络中这些激活统计量。在模型训练结束后,需要使用的训练数据对 swa_model 模型进行一次 forward 计算并更新这些统计量的值。从上面代码可以看到,BatchNorm 层在模型训练期结束后会调用函数 update_bn() 来计算激活统计量,并更新模型。
上面的 test() 函数,实现了使用测试集 test_dataloader 来计算 loss 值和测试精度,对应代码如下:
def test(dataloader, model, loss_fn): size = len(dataloader.dataset) num_batches = len(dataloader) model.eval() test_loss, correct = 0, 0 with torch.no_grad(): for X, y in dataloader: X, y = X.to(device), y.to(device) pred = model(X) test_loss += loss_fn(pred, y).item() correct += (pred.argmax(1) == y).type(torch.float).sum().item() test_loss /= num_batches correct /= size print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
运行上面代码,可以执行模型训练,训练过程输出信息示例如下:
Epoch 1 ------------------------------- loss: 0.034567 [ 64/60000] loss: 0.062562 [ 6464/60000] loss: 0.085734 [12864/60000] loss: 0.029585 [19264/60000] loss: 0.005262 [25664/60000] loss: 0.067925 [32064/60000] loss: 0.108129 [38464/60000] loss: 0.027133 [44864/60000] loss: 0.205612 [51264/60000] loss: 0.077616 [57664/60000] Epoch 2 ------------------------------- loss: 0.010665 [ 64/60000] loss: 0.066350 [ 6464/60000] loss: 0.074224 [12864/60000] loss: 0.042965 [19264/60000] loss: 0.005810 [25664/60000] loss: 0.064340 [32064/60000] loss: 0.106495 [38464/60000] loss: 0.023883 [44864/60000] loss: 0.189915 [51264/60000] loss: 0.081894 [57664/60000] Epoch 3 ... ... Epoch 9 ------------------------------- loss: 0.010908 [ 64/60000] loss: 0.073716 [ 6464/60000] loss: 0.044475 [12864/60000] loss: 0.083066 [19264/60000] loss: 0.019485 [25664/60000] loss: 0.056779 [32064/60000] loss: 0.118609 [38464/60000] loss: 0.020461 [44864/60000] loss: 0.145578 [51264/60000] loss: 0.121390 [57664/60000] Epoch 10 ------------------------------- loss: 0.029301 [ 64/60000] loss: 0.065231 [ 6464/60000] loss: 0.043666 [12864/60000] loss: 0.064136 [19264/60000] loss: 0.015114 [25664/60000] loss: 0.054391 [32064/60000] loss: 0.124274 [38464/60000] loss: 0.018437 [44864/60000] loss: 0.135725 [51264/60000] loss: 0.110678 [57664/60000] Test Error: Accuracy: 98.5%, Avg loss: 0.043567
3 使用 SWA 模型
使用模型,示例代码如下所示:
classes = [ "T-shirt/top", "Trouser", "Pullover", "Dress", "Coat", "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot", ] swa_model.eval() for i in range(10): start = 13 * i + 8 x, y = test_dataset[start][0], test_dataset[start][1] with torch.no_grad(): x = x.to(device) pred = swa_model(x) predicted, actual = classes[pred[0].argmax(0)], classes[y] print(f'Predicted: "{predicted:<12s}", Actual: "{actual:<12s}"')
可以看到,使用模型预测的结果示例:
Predicted: "Sandal ", Actual: "Sandal " Predicted: "Shirt ", Actual: "Shirt " Predicted: "Sneaker ", Actual: "Sneaker " Predicted: "Pullover ", Actual: "Pullover " Predicted: "Sneaker ", Actual: "Sneaker " Predicted: "Ankle boot ", Actual: "Ankle boot " Predicted: "Sneaker ", Actual: "Sneaker " Predicted: "Ankle boot ", Actual: "Ankle boot " Predicted: "Dress ", Actual: "Dress " Predicted: "Ankle boot ", Actual: "Ankle boot "
总结
我们从使用 SWA 能为我们带来的优势出发,并结合 SWA 论文实验给定的一些结论,对 SWA 进行总结:
- 使用 SWA 非常简单,有效提高模型训练性能
SWA 使用了一个可以在训练模型过程中不断修改的学习率调节器(Learning Rate Schedule),能够更快地到达最优解附近区域。而且,模型训练的时候,降低迭代次数(epochs)也能够很快得到比较高的精度。
- 得到更合适的模型参数
在 SWA 论文中,使用预激活(Preactivation) ResNet-164 模型,在 CIFAR-100 数据集上分别使用 SWA 和 SGD 训练模型,得到结果如下图所示:
上图中表明,使用 SGD,训练模型得到的 Train Loss 是最优的,但是对应的 Test Error 并不是最优的,说明这个最优值可能是一个局部最优值;而使用 SWA 得到的 Train Loss 不是最优的,但是对应的 Test Error 却是最优的,说明 SWA 得到的 Train Loss 是一个比 SGD 更优的全局解。
- 更好的泛化性能
从泛化性能方面来考虑,对比 SGD 和 SWA,如下图所示:
使用 SGD,得到的 Train Loss 是一个最优值,但是位于边界位置上,所以对应的 Test Error 变化幅度就会相对更大,从图中可以看到 Train Loss 是最优值但对应的 Test Error 并不是。
使用 SWA,能够得到一个解,它使 Train Loss 集中在一个足够宽的平滑区域范围内,而在这个区域内得到的对应 Test Error 是全局最优解,这样就能更有潜力使获得的模型具有更好的泛化性能。
- 使用 SWA 应用范围比较广泛
SWA 不仅可以用于 SGD 优化器,还可以用于其他一些优化器,比如 Adam。
SWA 还有其他一些扩展,如 SWAG、MultiSWAG、SWALP、SWAP,在对应的应用领域内,都能够得到比较好的效果。
另外,对于基于大模型的预训练场景,鉴于 SWA 的加快模型训练速度的优势,也能够更加广泛地被应用。
参考资源
- https://pytorch.org/docs/stable/optim.html#weight-averaging-swa-and-ema
- https://pytorch.org/blog/pytorch-1.6-now-includes-stochastic-weight-averaging/
- Averaging Weights Leads to Wider Optima and Better Generalization
本文基于署名-非商业性使用-相同方式共享 4.0许可协议发布,欢迎转载、使用、重新发布,但务必保留文章署名时延军(包含链接:http://shiyanjun.cn),不得用于商业目的,基于本文修改后的作品务必以相同的许可发布。如有任何疑问,请与我联系。