使用预训练模型有两种方式:一种是直接使用得到的预训练模型进行推理,并满足应用的需要,使用起来非常简单;另一种是在预训练模型的基础上,进行微调,使得到的新模型能够更好地满足我们解决问题的需要,这种方式需要能够对模型进行调优有一定门槛。这里,我们尝试第一种方式直接使用预训练模型,着重关注使用预训练模型处理图片分类的过程,从而熟悉在实际应用中都需要做哪些处理工作。
预训练模型
预训练模型(Pre-trained Models,PTMs)是一种深度学习架构,它在大规模数据集上进行训练,以获取丰富的特征表示。训练得到的模型可以进行复用,不仅能够适用于最初要解决的问题,还可以迁移到其他类似的应用场景中,从而提高在这些新领域的应用的性能。
预训练模型通常具有较大的参数规模,需要使用海量的数据和高昂的计算资源代价,才能完成模型训练并最终得到模型参数,这对于一些不具备基于超大规模数据训练能力的使用者来说,就无法发挥模型的作用,而且也不能很方便地在特定应用领域内探索并验证一些应用的想法。
例如,在 NLP 领域,预训练模型应用的特别广泛,因为它们可以从海量的文本数据中学习到有用的语义信息。而从头开始训练这些 NLP 模型需要大量的计算资源,这对于基于此类模型的下游应用场景几乎是不可能的,如解决诸如语言理解、机器翻译、自动问答等问题都受到了极大的限制。所以使用预训练模型,可以极大地降低下游应用场景使用的代价和复杂度,而把精力聚焦在特定的场景的问题上。通过直接使用预训练模型,或者进行简单的微调就能够很好地完成下游的一些任务,像文本分类、序列标注和阅读理解等,从而实现性能的提升。
在 PyTorch 中内置了很多预训练模型,我们可以直接通过 torchvision.models 提供的 API 来使用。查看当前模型库里面有哪些预训练模型:
from torchvision import models dir(models)
可以看到,有很多可以使用的经典神经网络的预训练模型,如 AleNet、ResNet、GoogLeNet、VGG 等,示例如下所示:
['AlexNet', 'AlexNet_Weights', 'ConvNeXt', 'ConvNeXt_Base_Weights', 'ConvNeXt_Large_Weights', 'ConvNeXt_Small_Weights', 'ConvNeXt_Tiny_Weights', 'DenseNet', ... ... 'ResNet', 'ResNet101_Weights', 'ResNet152_Weights', 'ResNet18_Weights', 'ResNet34_Weights', 'ResNet50_Weights', 'ShuffleNetV2', 'ShuffleNet_V2_X0_5_Weights', 'ShuffleNet_V2_X1_0_Weights', 'ShuffleNet_V2_X1_5_Weights', 'ShuffleNet_V2_X2_0_Weights', 'SqueezeNet', ... ... 'vgg', 'vgg11', 'vgg11_bn', 'vgg13', ... ...]
根据我们实际的资源和应用需求,可以选择对应的预训练模型来实现推理功能。
下面,我们使用具有 18 层深度的 ResNet 预训练神经网络模型,来说明如何对指定的任意图片进行分类。
使用预训练模型 ResNet-18 分类图片
下面我们把使用预训练模型的过程,分为 4 个步骤进行操作:
1 获取 ImageNet 数据集 label
ResNet 基于 ImageNet 数据集进行训练,所以 label 也是来自 ImageNet。可以从网上下载对应的 label 文件,我找到的是 caffe_ilsvrc12.tar.gz,解压缩后可以得到一个 synset_words.txt 文件,里面是关于图片的 1000 个 label,内容示例:
n01440764 tench, Tinca tinca n01443537 goldfish, Carassius auratus n01484850 great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias n01491361 tiger shark, Galeocerdo cuvieri n01494475 hammerhead, hammerhead shark n01496331 electric ray, crampfish, numbfish, torpedo
可以直接处理并提取文件中的 label 内容,也就是除了第一列以外,剩下的其它列的内容是一个包含多个词的 label:
with open("./synset_words.txt") as f: classes = [line.split(" ")[1:] for line in f.readlines()] print(len(classes))
把 label 名称直接加载到 classes 数组中,后面使用预训练模型推理后,需要找到对应的 label 名称。
2 加载 ResNet-18 预训练模型
直接使用 models.resnet18,会下载模型参数,并加载到内存:
from torchvision import models resnet = models.resnet18(pretrained=True) print(resnet)
可以看到 ResNet-18 模型的结构,如下所示:
ResNet( (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False) (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu): ReLU(inplace=True) (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False) (layer1): Sequential( (0): BasicBlock( (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu): ReLU(inplace=True) (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) (1): BasicBlock( (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu): ReLU(inplace=True) (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (layer2): Sequential( (0): BasicBlock( (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False) (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu): ReLU(inplace=True) (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (downsample): Sequential( (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False) (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (1): BasicBlock( (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu): ReLU(inplace=True) (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (layer3): Sequential( (0): BasicBlock( (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False) (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu): ReLU(inplace=True) (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (downsample): Sequential( (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False) (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (1): BasicBlock( (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu): ReLU(inplace=True) (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (layer4): Sequential( (0): BasicBlock( (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False) (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu): ReLU(inplace=True) (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (downsample): Sequential( (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False) (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (1): BasicBlock( (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu): ReLU(inplace=True) (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (avgpool): AdaptiveAvgPool2d(output_size=(1, 1)) (fc): Linear(in_features=512, out_features=1000, bias=True) )
可以看到,ResNet-18 网络模型的结构包含了哪些层,以及对应参数情况。
另外,也可以使用 PyTorch Hub 提供的 API 来直接加载对应的预训练模型,下载后继续使用模型。PyTorch Hub 提供了几种方式,我们通过代码片段简单说明如下,不过多实践了:
- 使用 torch.hub.load() 获取预训练模型
model = torch.hub.load('pytorch/vision', 'resnet18', pretrained=True)
- 使用 torch.hub.load_state_dict_from_url() 获取预训练模型
state_dict = torch.hub.load_state_dict_from_url('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth')
- 使用 torch.hub.download_url_to_file() 获取预训练模型
torch.hub.download_url_to_file('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth', '/tmp/temporary_file')
3 预处理待分类图片
我们随便找了一张带兔子的图片 ./myimages/rabbit.jpeg,你也可以拿其他图片测试:
然后对输入图片进行处理,得到满足 ResNet-18 模型推理输入要求的 Tensor,如下所示:
# Define preprocessing transformers chain preprocess = transforms.Compose([ transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) img = Image.open("./myimages/rabbit.jpeg") img = preprocess(img) input_tensor = torch.unsqueeze(img, 0)
使用上面得到的 input_tensor 就可以使用模型进行推理。
4 推理并得到图片 label
resnet.eval() output = resnet(input_tensor) _, indices = torch.sort(output, descending=True) [(I, classes[i], percentage[i].item()) for i in indices[0][:3]]
上面代码通过 output = resnet(input_tensor) 进行推理,得到一个 Shape 是 torch.Size([1, 1000]) 的结果 Tensor,这里面并没有直接给出分类的 label,我们需要处理一下:对其进行降序排序,并取出 Top 3 最大的分值,并转换成百分比,表示输入图片属于某一个 label 的概率;然后,计算得到这 3 个分值对应索引位置;最后,根据索引位置从 classes 数组得到对应的 label 名称。
运行代码,输出结果如下:
[(tensor(331), ['hare\n'], 71.56768035888672), (tensor(330), ['wood', 'rabbit,', 'cottontail,', 'cottontail', 'rabbit\n'], 27.239879608154297), (tensor(332), ['Angora,', 'Angora', 'rabbit\n'], 0.4852880835533142)]
通过结果可以看到,我们输入图片经过模型推理,得到的第一个推理结果分值最高, label 名称是 hare,“野兔”的意思,白兔的图片确实和 hare 这个 label 更加接近。 其它的两个 label 中,可能是在预训练模型训练中,数据集的图片里有其他更多内容,这里经过推理得到的分值并不高。
参考资源
- https://pytorch.org/hub/
- https://pytorch.org/docs/stable/hub.html
- Deep Learning with PyTorch
本文基于署名-非商业性使用-相同方式共享 4.0许可协议发布,欢迎转载、使用、重新发布,但务必保留文章署名时延军(包含链接:http://shiyanjun.cn),不得用于商业目的,基于本文修改后的作品务必以相同的许可发布。如有任何疑问,请与我联系。