基于 PyTorch 编程使用预训练模型

使用预训练模型有两种方式:一种是直接使用得到的预训练模型进行推理,并满足应用的需要,使用起来非常简单;另一种是在预训练模型的基础上,进行微调,使得到的新模型能够更好地满足我们解决问题的需要,这种方式需要能够对模型进行调优有一定门槛。这里,我们尝试第一种方式直接使用预训练模型,着重关注使用预训练模型处理图片分类的过程,从而熟悉在实际应用中都需要做哪些处理工作。

预训练模型

预训练模型(Pre-trained Models,PTMs)是一种深度学习架构,它在大规模数据集上进行训练,以获取丰富的特征表示。训练得到的模型可以进行复用,不仅能够适用于最初要解决的问题,还可以迁移到其他类似的应用场景中,从而提高在这些新领域的应用的性能。
预训练模型通常具有较大的参数规模,需要使用海量的数据和高昂的计算资源代价,才能完成模型训练并最终得到模型参数,这对于一些不具备基于超大规模数据训练能力的使用者来说,就无法发挥模型的作用,而且也不能很方便地在特定应用领域内探索并验证一些应用的想法。
例如,在 NLP 领域,预训练模型应用的特别广泛,因为它们可以从海量的文本数据中学习到有用的语义信息。而从头开始训练这些 NLP 模型需要大量的计算资源,这对于基于此类模型的下游应用场景几乎是不可能的,如解决诸如语言理解、机器翻译、自动问答等问题都受到了极大的限制。所以使用预训练模型,可以极大地降低下游应用场景使用的代价和复杂度,而把精力聚焦在特定的场景的问题上。通过直接使用预训练模型,或者进行简单的微调就能够很好地完成下游的一些任务,像文本分类、序列标注和阅读理解等,从而实现性能的提升。
在 PyTorch 中内置了很多预训练模型,我们可以直接通过 torchvision.models 提供的 API 来使用。查看当前模型库里面有哪些预训练模型:

from torchvision import models
dir(models)

可以看到,有很多可以使用的经典神经网络的预训练模型,如 AleNet、ResNet、GoogLeNet、VGG 等,示例如下所示:

['AlexNet',
 'AlexNet_Weights',
 'ConvNeXt',
 'ConvNeXt_Base_Weights',
 'ConvNeXt_Large_Weights',
 'ConvNeXt_Small_Weights',
 'ConvNeXt_Tiny_Weights',
 'DenseNet',
 ... ...
 'ResNet',
 'ResNet101_Weights',
 'ResNet152_Weights',
 'ResNet18_Weights',
 'ResNet34_Weights',
 'ResNet50_Weights',
 'ShuffleNetV2',
 'ShuffleNet_V2_X0_5_Weights',
 'ShuffleNet_V2_X1_0_Weights',
 'ShuffleNet_V2_X1_5_Weights',
 'ShuffleNet_V2_X2_0_Weights',
 'SqueezeNet',
 ... ...
 'vgg',
 'vgg11',
 'vgg11_bn',
 'vgg13',
 ... ...]

根据我们实际的资源和应用需求,可以选择对应的预训练模型来实现推理功能。
下面,我们使用具有 18 层深度的 ResNet 预训练神经网络模型,来说明如何对指定的任意图片进行分类。

使用预训练模型 ResNet-18 分类图片

下面我们把使用预训练模型的过程,分为 4 个步骤进行操作:

1 获取 ImageNet 数据集 label

ResNet 基于 ImageNet 数据集进行训练,所以 label 也是来自 ImageNet。可以从网上下载对应的 label 文件,我找到的是 caffe_ilsvrc12.tar.gz,解压缩后可以得到一个 synset_words.txt 文件,里面是关于图片的 1000 个 label,内容示例:

n01440764 tench, Tinca tinca
n01443537 goldfish, Carassius auratus
n01484850 great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias
n01491361 tiger shark, Galeocerdo cuvieri
n01494475 hammerhead, hammerhead shark
n01496331 electric ray, crampfish, numbfish, torpedo

可以直接处理并提取文件中的 label 内容,也就是除了第一列以外,剩下的其它列的内容是一个包含多个词的 label:

with open("./synset_words.txt") as f:
        classes = [line.split(" ")[1:] for line in f.readlines()]
print(len(classes))

把 label 名称直接加载到 classes 数组中,后面使用预训练模型推理后,需要找到对应的 label 名称。

2 加载 ResNet-18 预训练模型

直接使用 models.resnet18,会下载模型参数,并加载到内存:

from torchvision import models
resnet = models.resnet18(pretrained=True)
print(resnet)

可以看到 ResNet-18 模型的结构,如下所示:

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer2): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer3): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer4): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=512, out_features=1000, bias=True)
)

可以看到,ResNet-18 网络模型的结构包含了哪些层,以及对应参数情况。
另外,也可以使用 PyTorch Hub 提供的 API 来直接加载对应的预训练模型,下载后继续使用模型。PyTorch Hub 提供了几种方式,我们通过代码片段简单说明如下,不过多实践了:

  • 使用 torch.hub.load() 获取预训练模型
model = torch.hub.load('pytorch/vision', 'resnet18', pretrained=True)
  • 使用 torch.hub.load_state_dict_from_url() 获取预训练模型
state_dict = torch.hub.load_state_dict_from_url('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth')
  • 使用 torch.hub.download_url_to_file() 获取预训练模型
torch.hub.download_url_to_file('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth', '/tmp/temporary_file')

3 预处理待分类图片

我们随便找了一张带兔子的图片 ./myimages/rabbit.jpeg,你也可以拿其他图片测试:
rabbit
然后对输入图片进行处理,得到满足 ResNet-18 模型推理输入要求的 Tensor,如下所示:

# Define preprocessing transformers chain
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

img = Image.open("./myimages/rabbit.jpeg")
img = preprocess(img)
input_tensor = torch.unsqueeze(img, 0)

使用上面得到的 input_tensor 就可以使用模型进行推理。

4 推理并得到图片 label

resnet.eval()
output = resnet(input_tensor)
_, indices = torch.sort(output, descending=True)
[(I, classes[i], percentage[i].item()) for i in indices[0][:3]]

上面代码通过 output = resnet(input_tensor) 进行推理,得到一个 Shape 是 torch.Size([1, 1000]) 的结果 Tensor,这里面并没有直接给出分类的 label,我们需要处理一下:对其进行降序排序,并取出 Top 3 最大的分值,并转换成百分比,表示输入图片属于某一个 label 的概率;然后,计算得到这 3 个分值对应索引位置;最后,根据索引位置从 classes 数组得到对应的 label 名称。
运行代码,输出结果如下:

[(tensor(331), ['hare\n'], 71.56768035888672),
 (tensor(330),
  ['wood', 'rabbit,', 'cottontail,', 'cottontail', 'rabbit\n'],
  27.239879608154297),
 (tensor(332), ['Angora,', 'Angora', 'rabbit\n'], 0.4852880835533142)]

通过结果可以看到,我们输入图片经过模型推理,得到的第一个推理结果分值最高, label 名称是 hare,“野兔”的意思,白兔的图片确实和 hare 这个 label 更加接近。 其它的两个 label 中,可能是在预训练模型训练中,数据集的图片里有其他更多内容,这里经过推理得到的分值并不高。

参考资源

Creative Commons License

本文基于署名-非商业性使用-相同方式共享 4.0许可协议发布,欢迎转载、使用、重新发布,但务必保留文章署名时延军(包含链接:http://shiyanjun.cn),不得用于商业目的,基于本文修改后的作品务必以相同的许可发布。如有任何疑问,请与我联系

发表评论

电子邮件地址不会被公开。 必填项已用*标注

您可以使用这些HTML标签和属性: <a href="" title=""> <abbr title=""> <acronym title=""> <b> <blockquote cite=""> <cite> <code> <del datetime=""> <em> <i> <q cite=""> <strike> <strong>