侧边栏壁纸
  • 累计撰写 55 篇文章
  • 累计创建 34 个标签
  • 累计收到 7 条评论

目 录CONTENT

文章目录

PyTorch 学习笔记 第5课 CNN模型的迁移学习

NormanZhu
2020-10-20 / 0 评论 / 0 点赞 / 477 阅读 / 8,662 字
温馨提示:
本文最后更新于 2020-10-20,若内容或图片失效,请留言反馈。部分素材来自网络,若不小心影响到您的利益,请联系我们删除。

这节课涉及到的知识比较广,而我跟随视频学习后,也并没有充足的精力将背后的技巧和其他知识一并学习,所以文中肯定有不正确的表述,还请指正。

简介

  • 很多时候我们需要训练一个新的图像分类任务,我们不会完全从一个随机的模型开始训练,而是利用预训练过的模型继续训练,这就是transfer learning 的方法。

  • 我们常用以下两种方法做迁移学习:

    • fine tuning:从一个预训练模型开始,改变一些模型的架构,然后继续训练整个模型的参数
    • feature extraction:我们不改变预训练模型的参数,而是只更新我们改变过的那部分模型参数
  • 构建和训练迁移学习模型的基本步骤:

    1. 初始化预训练模型
    2. 把最后一层的输出层改编成我们想要的分类总数
    3. 定义一个optimizer来更新参数
    4. 模型训练

背景

使用一个数据集,包括蜜蜂和蚂蚁的图片,我们的任务是训练一个模型能够将一个图片分类成蜜蜂或蚂蚁。

在这个任务中会使用到torchvision,它包含了许多流行的数据集、模型和图像转换的工具 。

准备工作

# 引入
import torch
import numpy as np
import torch.nn as nn
import torchvision
from torchvision import datasets, transforms, models

import matplotlib.pyplot as plt
import time
import os
import copy
print("Torchvision Version: ", torchvision.__version__)
data_dir = './hymenoptera_data'	# 数据集的路径
model_name = 'resnet'
num_classes = 2		# 分类的数量
batch_size = 32	
num_epochs = 15		
feature_extract = True	# 使用feature extract

input_size = 224	# 这里是224是因为将要把图片裁切成224*224

读入数据

此部分需要使用到torchvisiondatasetstransforms

其中transforms包含了很多常用的图像转换工具,也可以被组织成一个链的形状,需要用到transforms.Compose([xxx, xxx]),详情可以查看

简单介绍将要使用到的几个Transform:

  • RandomResizedCrop(*size*, *scale=(0.08*, *1.0)*, *ratio=(0.75*, *1.3333333333333333)*, *interpolation=2*)

    • 将一个图像随机裁切到指定的大小
  • ``RandomHorizontalFlip(*p=0.5*)

    • 用于将图片水平翻转,增加数据集的噪声,提高训练效果
    • p是概率,这张图片有多大的概率被翻转
  • ToTensor()

    • 将图像转换为Tensor
  • Transforms

# 训练和测试的transformer是不一样的
data_transforms = {
    'train_data': transforms.Compose([
        transforms.RandomResizedCrop(input_size),
        transforms.RandomHorizontalFlip(),  # 更加noisy
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]),
    'val_data': transforms.Compose([
        transforms.RandomResizedCrop(input_size),
        transforms.CenterCrop(input_size),  # 更加noisy
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]),
}

上面的Normanlize的参数为什么是那样我也没弄明白,也许是老师提前计算过了的吧。

  • Datasets

定义完了Transform,接下来是初始化数据集。这里用了点Python的技巧,实际上做的就是构造了一个字典,key分别为train_dataval_data,这样以后需要使用train或val的数据集的时候就能够很方便的从一个变量里面获取。

image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x]) for x in ['train_data', 'val_data']}
  • DataLoaders

这里也用了类似上面的技巧。

dataloaders_dict = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size, shuffle=True, num_workers=4) for x in ['train_data', 'val_data']}

DataLoader 有什么用?看看官网的介绍,就可以知道它实现了很多便于我们对数据集进行操作的函数和功能,比如迭代、自动分batch等:

dataloader

  • Device

别忘了,如果有块GPU,记得把device定义成cuda

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

展示数据

这一段代码直接使用就好了

unloader = transforms.ToPILImage()
plt.ion()

def imshow(tensor, title=None):
    image = tensor.cpu().clone()
    image = image.squeeze(0)
    image = unloader(image)
    plt.imshow(image)
    if title is not None:
        plt.title(title)
    plt.pause(0.001)    # pause a bit so that plots are updated
    
plt.figure()
imshow(img[31], title='Image')

运行后可以看到一张图片(可能会不太一样,因为经过了随机裁切)

example31

模型初始化

这里将使用到torchvision的models中的resnetxxx,比如下面使用的resnet18,意味着是一个有18层的resnet,同时,我们也传入use_pretrained参数,表示我们是否需要使用预训练的模型。

set_parameter_requies_grad()函数的作用是:如果不需要fine tune,则把requires_grad设为False(就不更新)

这个初始化模型的函数能够扩展到很多的使用场景(有封装的思想),只需要进一步扩展这个函数的实现即可,这里只实现了一个resnet的情况。

def set_parameter_required_grad(model, feature_extract):
    if feature_extract:
        for param in model.parameters():
            param.required_grad = False
            
def initialize_model(model_name, num_classes, feature_extract, use_pretrained=True):
    if model_name == 'resnet':
        model_ft = models.resnet18(pretrained=use_pretrained)
        set_parameter_required_grad(model_ft, feature_extract)
        num_ftrs = model_ft.fc.in_features  # num of features
        model_ft.fc = nn.Linear(num_ftrs, num_classes)
        input_size = 224
    else:
        print('model not implemented')
        return None, None

    return model_ft, input_size

下面初始化模型:

model_ft, input_size = initialize_model(model_name, num_classes, feature_extract, use_pretrained=True)
print(model_ft)

print一下可以查看模型的具体结构

模型训练

  • 定义训练和验证的通用函数
def train_model(model, dataloaders, loss_fn, optimizer, num_epochs=5):
    best_acc = 0.
    best_model_wts = copy.deepcopy(model.state_dict())
    val_acc_history = []
    for epoch in range(num_epochs):
        for phase in ['train', 'val']:
            running_loss = 0.
            running_corrects = 0.

            if phase == 'train':
                model.train()
            else:
                model.eval()

            for inputs, labels in dataloaders[phase + '_data']:
                inputs, labels = inputs.to(device), labels.to(device)

                with torch.autograd.set_grad_enabled(phase=='train'):
                    outputs = model(inputs)
                    loss = loss_fn(outputs, labels)

                preds = outputs.argmax(dim=1)
                if phase == 'train':
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds.view(-1) == labels.view(-1)).item()

            epoch_loss = running_loss / len(dataloaders[phase + '_data'].dataset)
            epoch_acc = running_corrects / len(dataloaders[phase + '_data'].dataset)

            print('Phase {} loss: {}, acc: {}'.format(phase, epoch_loss, epoch_acc))

            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

            if phase == 'val':
                val_acc_history.append(epoch_acc)

    model.load_state_dict(best_model_wts)

    return model, val_acc_history

解释一下这部分代码。首先说明这个函数的作用是训练模型,同时进行验证,然后每隔一个epoch就打印一次误差的信息。因此我们先定义好最佳的准确率best_accbest_model_wts,在获得下一次最佳准确率的时候进行一次模型的保存。保存val_acc_history是为了后面的可视化的(观察误差的变化情况)。

然后开始遍历(第5行),我们将分为两个阶段,trainval分别对应训练和验证,在各自的阶段调用model.train()model.eval()在各自两个阶段要做的事情跟以前的课程内容相同,这里面只是使用了if进行判断,好比验证阶段需要在with torch.no_grad()中进行等等。

完成迭代后,模型读取刚刚保存的最佳参数然后返回模型和准确率的变化情况。

另外,补充一下view(-1)的意思:view就是类似reshape,其中传入的参数在为-1的时候表示这个维度的大小由其他维度推算而得,如果只有一个-1就意味着转成一维的Tensor。

  • 训练过程
model_ft = model_ft.to(device)
# 补充:下面这个 filter 和lambda 做的事情是将参数中所有requires_grad的参数拿出来优化,其他的不动
optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, model_ft.parameters()),
                            lr=0.001, momentum=0.9)

loss_fn = nn.CrossEntropyLoss()

train_results = train_model(model_ft, dataloaders_dict, loss_fn, optimizer, num_epochs=num_epochs)

这样模型就完成了训练。但是这里面是使用了预训练的模型继续训练,并非一个完全随机的作为初始状态,所以一开始的准确率就已经达到90%了,下面是一部分我在训练的时候的输出:

Phase train loss: 0.783042344890657, acc: 0.4672131147540984
Phase val loss: 0.5575160146538728, acc: 0.673202614379085
Phase train loss: 0.5613641787747867, acc: 0.6680327868852459
Phase val loss: 0.3377787022808798, acc: 0.9019607843137255
Phase train loss: 0.3378620435957049, acc: 0.8647540983606558
Phase val loss: 0.25312920518560345, acc: 0.9150326797385621
Phase train loss: 0.25712054137323725, acc: 0.9016393442622951

如果不用预训练的模型和feature extract

只用把之前初始化模型的函数调用改成两个False就可以了

model_scratch, _ = initialize_model(model_name, num_classes, feature_extract=False, use_pretrained=False)
model_scratch = model_scratch.to(device)
optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, model_ft.parameters()),
                            lr=0.01, momentum=0.9)
loss_fn = nn.CrossEntropyLoss()
_, scratch_hist =  train_model(model_scratch, dataloaders_dict, loss_fn, optimizer, num_epochs=num_epochs)

这里训练出来的效果就明显不那么好了,一会儿在可视化中对比。

可视化

plt.title('Validation Accuracy vs. Number of Training Epochs')
plt.xlabel('Training Epochs')
plt.ylabel('Validation Accuracy')
plt.plot(range(1, num_epochs + 1), train_results[1], label='Pretrained')
plt.plot(range(1, num_epochs + 1), scratch_hist, label='Scratch')
plt.ylim((0, 1.))
plt.xticks(np.arange(1, num_epochs + 1, 1.0))
plt.legend()
plt.show()

visualization

0

评论区