使用 PyTorch SWA 优化模型训练入门实践

PyTorch 实现了 SWA(Stochastic Weight Averaging,随机加权平均),相比于传统的 SGD,使用 SWA 能够明显改善一些深度神经网络模型的测试精度(Test Accuracy)。而且,SWA 使用起来非常简单,能够加速模型训练,并提高模型的泛化能力。

SWA 基本原理

SWA 依赖两个重要的因素:
第一个是,SWA 使用一个不断修改的 LR 调节器(Learning Rate Schedule),使得 SGD 能够在最优值附近进行调整,并评估最优解附近的值对应的模型的精度,而不是只选取最优解对应的模型。因为,最优解对应的模型不一定是最优的,而且泛化能力可能也不一定最好。比如,在 75% 的训练时间里,可以使用一个标准的衰减学习率(Decaying Learning Rate)策略,然后在剩余 25% 的训练时间里将学习率设置为一个比较高的固定值。如下图所示:
SWA
第二个是,SWA 计算的是 SGD 遍历过的神经网络权重的平均值。例如,上面提到模型训练的后 25% 时间,我们可以在这 25% 时间里的每一轮训练(every epoch)后,计算一个权重的 running 平均值,在训练结束后再设置网络模型的权重为 SWA 权重平均值。

SWA 论文提供了对其算法的过程的描述,如下图所示:
SWA-Algorithm
上面算法描述,给出了下面几个重要的内容:

  • 学习率在模型训练过程中是不断调整的,同时更新 SGD 梯度值
  • 基于模型训练的迭代次数,设置在训练过程中不同的阶段(位置)更新两个不同的权重平均值
  • 在模型训练结束后,为 SWA 权重计算 BatchNorm 统计量

SWA 使用要点

在 PyTorch 中,使用 SWA 训练模型的基本要点,描述如下:

  • 创建 SWA 模型

创建 SWA 模型,直接使用 PyTorch 提供的 AveragedModel,传入上面创建的 model,其中 model 可以是任意继承了 torch.nn.Module 的模型实现:

swa_model = AveragedModel(model)

这样 swa_model 会在模型训练过程中持续跟踪 model 的参数的平均值。如果希望更新这个 running 平均值,需要在 optimizer.step() 之后调用 update_parameters() 函数:

swa_model.update_parameters(model)
  • 使用 SWALR

使用 SWA 时,通常配合使用 SWALR 这个 Learning Rate Scheduler 来退火(anneals)至一个固定的常量值,然后一直保持不变,使用示例:

swa_scheduler = torch.optim.swa_utils.SWALR(
    optimizer, anneal_strategy="linear", anneal_epochs=5, swa_lr=0.05
)
# 或者
swa_scheduler = torch.optim.swa_utils.SWALR(
    optimizer, anneal_strategy="cos", anneal_epochs=5, swa_lr=0.05
)
  • 计算模型在 DataLoader 上的 BN 统计量

在模型训练结束之后,需要进行计算 BN(Batch Normalization)统计量并更新模型:

torch.optim.swa_utils.update_bn(train_dataloader, swa_model)

SWA 编程实践

下面,我们基于上面提到的使用 SWA 的方法训练模型,通过实际编程来加强理解:

1 准备数据集并定义模型

我们使用 MNIST 数据集,神经网络模型使用 LeNet-5,代码处理逻辑如下所示:

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

# Download training/test data from open datasets.
train_dataset = datasets.MNIST(root="data", train=True, download=True, transform=ToTensor(),)
test_dataset = datasets.MNIST(root="data", train=False, download=False, transform=ToTensor(),)
print(f"train_dataset_size = {len(train_dataset)}, test_dataset_size = {len(test_dataset)}")

batch_size = 64

# Create data loaders.
train_dataloader = DataLoader(train_dataset, batch_size=batch_size)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size)

for X, y in test_dataloader:
    print(f"Shape of X [N, C, H, W]: {X.shape}")
    print(f"Shape of y: {y.shape} {y.dtype}")
    break

# Get cpu, gpu or mps device for training.
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
print(f"Using {device} device")

# Define model
class LeNet5Model(nn.Module):
    def __init__(self):
        super().__init__()
        self._conv = nn.Sequential(
            nn.Conv2d(1, 6, 5, 1),
            nn.MaxPool2d(2),
            nn.Conv2d(6, 16, 5, 1),
            nn.MaxPool2d(2)
        )
        self._fc = nn.Sequential(
            nn.Linear(4*4*16, 120),
            nn.Linear(120, 84),
            nn.Linear(84, 10)
        )

    def forward(self, x):
        x = self._conv(x)
        x = x.view(-1, 4 * 4 * 16)
        x = self._fc(x)
        return x

代码比较容易,不再赘述。

2 使用 SWA 训练模型

首先,我们创建模型,并指定要使用的 Loss Function 和 Optimizer,实现代码如下所示:

# Create model
model = LeNet5Model().to(device)
print(model)

# Define loss function and optimizer
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

使用 SWA,直接在以往的编程基础上,复用上面创建的一些对象即可。
下面,实现基于 SWA 训练模型的核心逻辑代码,如下所示:

from torch.optim.swa_utils import AveragedModel, SWALR
from torch.optim.lr_scheduler import CosineAnnealingLR

def swa_train(epoch, train_loader, model, loss_fn, optimizer, swa_start):
    # scheduler = CosineAnnealingLR(optimizer, T_max=10)
    scheduler = SWALR(optimizer, anneal_strategy="linear", anneal_epochs=5, swa_lr=0.05)
    swa_scheduler = SWALR(optimizer, swa_lr=0.05)
    swa_model = AveragedModel(model)
    size = len(train_loader.dataset)
    for batch, (X, y) in enumerate(train_loader):
        optimizer.zero_grad()
        loss = loss_fn(model(X), y)
        loss.backward()
        optimizer.step()
        if batch % 100 == 0:
            loss, current = loss.item(), (batch + 1) * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")
                
        if epoch > swa_start:
            swa_model.update_parameters(model)
            swa_scheduler.step()
        else:
            scheduler.step()
    return swa_model

在模型训练过程中,使用 swa_start 控制启动 SWA 更新模型参数,并使用 swa_scheduler 来调节学习率,否则就使用默认的学习率调节器控制训练过程。
现在,我们就可以调用上面提供的 SWA 模型训练函数 swa_train(),训练我们前面定义的 LeNet 模型,代码如下:

epochs = 10
swa_start = 7
for epoch in range(epochs):
    print(f"Epoch {epoch + 1}\n-------------------------------")
    swa_model = swa_train(epoch, train_dataloader, model, loss_fn, optimizer, swa_start) 

# Update bn statistics for the swa_model at the end
torch.optim.swa_utils.update_bn(train_dataloader, swa_model)
test(test_dataloader, swa_model, loss_fn)

在调用 swa_train() 函数训练的迭代过程中,并没有对模型进行预测调用,所以 Batch Norm 层也就没有计算过神经网络中这些激活统计量。在模型训练结束后,需要使用的训练数据对 swa_model 模型进行一次 forward 计算并更新这些统计量的值。从上面代码可以看到,BatchNorm 层在模型训练期结束后会调用函数 update_bn() 来计算激活统计量,并更新模型。
上面的 test() 函数,实现了使用测试集 test_dataloader 来计算 loss 值和测试精度,对应代码如下:

def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

运行上面代码,可以执行模型训练,训练过程输出信息示例如下:

Epoch 1
-------------------------------
loss: 0.034567  [   64/60000]
loss: 0.062562  [ 6464/60000]
loss: 0.085734  [12864/60000]
loss: 0.029585  [19264/60000]
loss: 0.005262  [25664/60000]
loss: 0.067925  [32064/60000]
loss: 0.108129  [38464/60000]
loss: 0.027133  [44864/60000]
loss: 0.205612  [51264/60000]
loss: 0.077616  [57664/60000]
Epoch 2
-------------------------------
loss: 0.010665  [   64/60000]
loss: 0.066350  [ 6464/60000]
loss: 0.074224  [12864/60000]
loss: 0.042965  [19264/60000]
loss: 0.005810  [25664/60000]
loss: 0.064340  [32064/60000]
loss: 0.106495  [38464/60000]
loss: 0.023883  [44864/60000]
loss: 0.189915  [51264/60000]
loss: 0.081894  [57664/60000]
Epoch 3
... ...
Epoch 9
-------------------------------
loss: 0.010908  [   64/60000]
loss: 0.073716  [ 6464/60000]
loss: 0.044475  [12864/60000]
loss: 0.083066  [19264/60000]
loss: 0.019485  [25664/60000]
loss: 0.056779  [32064/60000]
loss: 0.118609  [38464/60000]
loss: 0.020461  [44864/60000]
loss: 0.145578  [51264/60000]
loss: 0.121390  [57664/60000]
Epoch 10
-------------------------------
loss: 0.029301  [   64/60000]
loss: 0.065231  [ 6464/60000]
loss: 0.043666  [12864/60000]
loss: 0.064136  [19264/60000]
loss: 0.015114  [25664/60000]
loss: 0.054391  [32064/60000]
loss: 0.124274  [38464/60000]
loss: 0.018437  [44864/60000]
loss: 0.135725  [51264/60000]
loss: 0.110678  [57664/60000]
Test Error: 
 Accuracy: 98.5%, Avg loss: 0.043567 

3 使用 SWA 模型

使用模型,示例代码如下所示:

classes = [
    "T-shirt/top",
    "Trouser",
    "Pullover",
    "Dress",
    "Coat",
    "Sandal",
    "Shirt",
    "Sneaker",
    "Bag",
    "Ankle boot",
]

swa_model.eval()
for i in range(10):
    start = 13 * i + 8
    x, y = test_dataset[start][0], test_dataset[start][1]
    with torch.no_grad():
        x = x.to(device)
        pred = swa_model(x)
        predicted, actual = classes[pred[0].argmax(0)], classes[y]
        print(f'Predicted: "{predicted:<12s}", Actual: "{actual:<12s}"')

可以看到,使用模型预测的结果示例:

Predicted: "Sandal      ", Actual: "Sandal      "
Predicted: "Shirt       ", Actual: "Shirt       "
Predicted: "Sneaker     ", Actual: "Sneaker     "
Predicted: "Pullover    ", Actual: "Pullover    "
Predicted: "Sneaker     ", Actual: "Sneaker     "
Predicted: "Ankle boot  ", Actual: "Ankle boot  "
Predicted: "Sneaker     ", Actual: "Sneaker     "
Predicted: "Ankle boot  ", Actual: "Ankle boot  "
Predicted: "Dress       ", Actual: "Dress       "
Predicted: "Ankle boot  ", Actual: "Ankle boot  "

总结

我们从使用 SWA 能为我们带来的优势出发,并结合 SWA 论文实验给定的一些结论,对 SWA 进行总结:

  • 使用 SWA 非常简单,有效提高模型训练性能

SWA 使用了一个可以在训练模型过程中不断修改的学习率调节器(Learning Rate Schedule),能够更快地到达最优解附近区域。而且,模型训练的时候,降低迭代次数(epochs)也能够很快得到比较高的精度。

  • 得到更合适的模型参数

在 SWA 论文中,使用预激活(Preactivation) ResNet-164 模型,在 CIFAR-100 数据集上分别使用 SWA 和 SGD 训练模型,得到结果如下图所示:
Weight-Comparison-SGD-vs-SWA
上图中表明,使用 SGD,训练模型得到的 Train Loss 是最优的,但是对应的 Test Error 并不是最优的,说明这个最优值可能是一个局部最优值;而使用 SWA 得到的 Train Loss 不是最优的,但是对应的 Test Error 却是最优的,说明 SWA 得到的 Train Loss 是一个比 SGD 更优的全局解。

  • 更好的泛化性能

从泛化性能方面来考虑,对比 SGD 和 SWA,如下图所示:
SWA-Generalization
使用 SGD,得到的 Train Loss 是一个最优值,但是位于边界位置上,所以对应的 Test Error 变化幅度就会相对更大,从图中可以看到 Train Loss 是最优值但对应的 Test Error 并不是。
使用 SWA,能够得到一个解,它使 Train Loss 集中在一个足够宽的平滑区域范围内,而在这个区域内得到的对应 Test Error 是全局最优解,这样就能更有潜力使获得的模型具有更好的泛化性能。

  • 使用 SWA 应用范围比较广泛

SWA 不仅可以用于 SGD 优化器,还可以用于其他一些优化器,比如 Adam。
SWA 还有其他一些扩展,如 SWAG、MultiSWAG、SWALP、SWAP,在对应的应用领域内,都能够得到比较好的效果。
另外,对于基于大模型的预训练场景,鉴于 SWA 的加快模型训练速度的优势,也能够更加广泛地被应用。

参考资源

Creative Commons License

本文基于署名-非商业性使用-相同方式共享 4.0许可协议发布,欢迎转载、使用、重新发布,但务必保留文章署名时延军(包含链接:http://shiyanjun.cn),不得用于商业目的,基于本文修改后的作品务必以相同的许可发布。如有任何疑问,请与我联系

发表评论

电子邮件地址不会被公开。 必填项已用*标注

您可以使用这些HTML标签和属性: <a href="" title=""> <abbr title=""> <acronym title=""> <b> <blockquote cite=""> <cite> <code> <del datetime=""> <em> <i> <q cite=""> <strike> <strong>