- 🍨 This post is a study log from the 🔗365-Day Deep Learning Training Camp
- 🍖 Original author: K同学啊 | Tutoring and custom projects available
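Table of Contents
- Environment
- Steps
- Environment Setup
- Data Preparation
- Inspecting the Images
- Model Design
- Model Training
- Model Results
- Summary and Takeaways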
Environment
- OS: Linux
- Language: Python 3.8.10
- Deep learning framework: PyTorch 2.0.0+cu118
Steps
Environment Setup
Imports
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
import copy, random, pathlib
import matplotlib.pyplot as plt
from PIL import Image
from torchinfo import summary
import numpy as np
Set a global device so that the model and the data created later are all placed on the same device:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
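Data Preparation
Download the bird dataset from the cloud drive provided by K同学啊 and extract it into the data directory. Inside data/bird_photos, each subfolder holds the images of one bird species, so this layout can be loaded directly with torchvision.datasets.ImageFolder.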
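The labels come purely from this directory layout. The pure-Python sketch below imitates the rule that torchvision's ImageFolder uses to build class_to_idx: every sub-directory is one class, and indices follow the sorted order of the folder names. The folder names are hypothetical placeholders, not the actual species in the bird dataset.

```python
import os
import tempfile
from pathlib import Path


def find_classes(root):
    """Imitate ImageFolder's labelling rule: each sub-directory of `root`
    is one class, indexed in sorted (alphabetical) order of folder names.
    Plain files directly under `root` are ignored."""
    classes = sorted(entry.name for entry in os.scandir(root) if entry.is_dir())
    class_to_idx = {name: i for i, name in enumerate(classes)}
    return classes, class_to_idx


# Hypothetical layout: placeholder class folders, created out of order on purpose
with tempfile.TemporaryDirectory() as root:
    for name in ["class_c", "class_a", "class_b"]:
        (Path(root) / name).mkdir()
        (Path(root) / name / "img_0.jpg").touch()
    classes, class_to_idx = find_classes(root)
    print(classes)       # ['class_a', 'class_b', 'class_c']
    print(class_to_idx)  # {'class_a': 0, 'class_b': 1, 'class_c': 2}
```

Because the order is alphabetical, renaming a folder can change which integer label a species receives.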
Inspecting the Images
1. Collect all image paths
root_dir = 'data/bird_photos'
root_directory = pathlib.Path(root_dir)
image_list = list(root_directory.glob("*/*"))  # glob returns a generator; random.choice needs a list
2. Print the shapes of 5 randomly chosen images
for _ in range(5):
    print(np.array(Image.open(str(random.choice(image_list)))).shape)
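All of the sampled images are 224×224 with three channels. The Resize step can therefore be skipped during preprocessing, or a Resize((224, 224)) can be added as a safeguard against any image with a different size.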
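Five random samples are only a spot check. The minimal sketch below tallies the size and colour mode of every file instead. summarize_image_sizes is a hypothetical helper, and the */* pattern assumes the one-folder-per-class layout described above. PIL's Image.open reads the header lazily, so the pixel data is not decoded during the scan.

```python
from collections import Counter
from pathlib import Path

from PIL import Image


def summarize_image_sizes(root, pattern="*/*"):
    """Count (width, height, mode) over every file matching `pattern` under `root`.
    Only image headers are read; pixels are never decoded."""
    counts = Counter()
    for path in Path(root).glob(pattern):
        with Image.open(path) as img:
            counts[(img.width, img.height, img.mode)] += 1
    return counts
```

For example, summarize_image_sizes('data/bird_photos') returning a single key such as (224, 224, 'RGB') would confirm that Resize is unnecessary. ImageFolder's default loader converts images to RGB, so a stray grayscale or RGBA file would still arrive with three channels.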
3. Display 20 random images
plt.figure(figsize=(20, 4))
for i in range(20):
    plt.subplot(2, 10, i + 1)
    plt.axis('off')
    image = random.choice(image_list)
    class_name = image.parts[-2]  # the parent folder name is the class name
    plt.title(class_name)
    plt.imshow(Image.open(str(image)))
4. Create the dataset
First, define the image preprocessing pipeline:
transform = transforms.Compose([
    transforms.ToTensor(),  # PIL image (H, W, C) in [0, 255] -> tensor (C, H, W) in [0, 1]
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # ImageNet channel statistics
])
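ToTensor scales 8-bit pixels to [0, 1], and Normalize then computes (x - mean) / std for each channel. The mean and std used here are the widely used ImageNet channel statistics, not statistics computed on the bird dataset. The sketch below applies the same arithmetic to single pixel values so you can see the input range each channel ends up with.

```python
# ImageNet per-channel statistics, the same values passed to transforms.Normalize
MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)


def normalize_pixel(value_uint8, channel):
    """Reproduce ToTensor + Normalize for one 8-bit pixel value in one channel."""
    x = value_uint8 / 255.0                      # ToTensor: [0, 255] -> [0, 1]
    return (x - MEAN[channel]) / STD[channel]    # Normalize: (x - mean) / std


# Range of normalized values for black (0) and white (255) pixels per channel
for c, name in enumerate("RGB"):
    lo, hi = normalize_pixel(0, c), normalize_pixel(255, c)
    print(f"{name}: [{lo:.3f}, {hi:.3f}]")
```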
Then load the folder with datasets.ImageFolder:
dataset = datasets.ImageFolder(root_dir, transform=transform)
Extract the class names from the dataset:
class_names = dataset.classes  # sorted folder names, in label-index order
Split into a training set and a validation set (80% / 20%):
train_size = int(len(dataset) * 0.8)
test_size = len(dataset) - train_size
train_dataset, test_dataset = random_split(dataset, [train_size, test_size])
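random_split draws a new permutation on every run, so the train/test membership, and with it the reported accuracy, changes from run to run. Below is a sketch of a reproducible split that passes a seeded torch.Generator to random_split. The helper name split_dataset and the seed 42 are arbitrary choices, and a toy TensorDataset stands in for the image dataset.

```python
import torch
from torch.utils.data import TensorDataset, random_split


def split_dataset(dataset, train_ratio=0.8, seed=42):
    """Split `dataset` into train/test subsets; the seeded generator makes
    random_split return the same indices on every run."""
    train_size = int(len(dataset) * train_ratio)
    test_size = len(dataset) - train_size
    generator = torch.Generator().manual_seed(seed)
    return random_split(dataset, [train_size, test_size], generator=generator)


# Toy stand-in for the ImageFolder dataset: 10 dummy samples
toy = TensorDataset(torch.arange(10))
train_a, test_a = split_dataset(toy)
train_b, test_b = split_dataset(toy)
print(len(train_a), len(test_a))                       # 8 2
print(list(train_a.indices) == list(train_b.indices))  # True
```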
Finally, wrap both splits in DataLoaders to produce batches:
batch_size = 8
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)
test_loader = DataLoader(test_dataset, batch_size=batch_size)
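Model Design
After analyzing K同学啊's TensorFlow code, I found that the hidden layers inside each block always use a single channel count, so I replaced the original filters parameter with hidden_channel. In addition, unlike the TensorFlow layers, PyTorch's nn.Conv2d does not infer its number of input channels, so an extra parameter is needed to pass it in.
1. Write the IdentityBlock
IdentityBlock is a residual block whose input and output channel counts are identical, so the input can be added to the block's output directly.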
class IdentityBlock(nn.Module):
    def __init__(self, kernel_size, input_channel, hidden_channel):
        super().__init__()
        # 1x1 conv: reduce the channels to hidden_channel
        self.conv1 = nn.Sequential(nn.Conv2d(input_channel, hidden_channel, 1), nn.BatchNorm2d(hidden_channel), nn.ReLU())
        # kernel_size x kernel_size conv at the reduced width; 'same' padding keeps H and W
        self.conv2 = nn.Sequential(nn.Conv2d(hidden_channel, hidden_channel, kernel_size, padding='same'), nn.BatchNorm2d(hidden_channel), nn.ReLU())
        # 1x1 conv: restore input_channel so the input can be added back
        self.conv3 = nn.Sequential(nn.Conv2d(hidden_channel, input_channel, 1), nn.BatchNorm2d(input_channel))
        self.relu = nn.ReLU()

    def forward(self, inputs):
        x = self.conv1(inputs)
        x = self.conv2(x)
        x = self.conv3(x)
        x = inputs + x  # identity shortcut
        x = self.relu(x)
        return x
Create a dummy input to check the output shape of the block:
inputs = torch.zeros((8, 256, 224, 224))
model = IdentityBlock(3, 256, 64)
outputs = model(inputs)
print(inputs.shape)
print(outputs.shape)
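2. Write the ConvBlock
ConvBlock is not an identity block. Its input and output channel counts differ, and with stride 2 it also downsamples the feature map, so its shortcut uses a strided 1×1 convolution to match the shapes before the addition.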
class ConvBlock(nn.Module):
    def __init__(self, kernel_size, input_channel, output_channel, hidden_channel, stride=2):
        super().__init__()
        # 1x1 conv reduces the channels; stride (default 2) downsamples H and W
        self.conv1 = nn.Sequential(nn.Conv2d(input_channel, hidden_channel, 1, stride=stride), nn.BatchNorm2d(hidden_channel), nn.ReLU())
        self.conv2 = nn.Sequential(nn.Conv2d(hidden_channel, hidden_channel, kernel_size, padding='same'), nn.BatchNorm2d(hidden_channel), nn.ReLU())
        # 1x1 conv expands to output_channel
        self.conv3 = nn.Sequential(nn.Conv2d(hidden_channel, output_channel, 1), nn.BatchNorm2d(output_channel))
        # projection shortcut: 1x1 conv with the same stride so both branches have the same shape
        self.shortcut = nn.Sequential(nn.Conv2d(input_channel, output_channel, 1, stride=stride), nn.BatchNorm2d(output_channel))
        self.relu = nn.ReLU()

    def forward(self, inputs):
        x = self.conv1(inputs)
        x = self.conv2(x)
        x = self.conv3(x)
        x = x + self.shortcut(inputs)
        x = self.relu(x)
        return x
Test the block with dummy data:
inputs = torch.zeros((8, 256, 224, 224))
model = ConvBlock(3, 256, 512, 256, 2)
outputs = model(inputs)
print(inputs.shape)
print(outputs.shape)
3. Write ResNet-50
class ResNet50(nn.Module):
    def __init__(self, input_channel, classes=1000):
        super().__init__()
        # stem: 7x7 conv with stride 2, then a 3x3 max-pool with stride 2
        self.conv1 = nn.Sequential(nn.Conv2d(input_channel, 64, 7, stride=2, padding=3), nn.BatchNorm2d(64), nn.ReLU(), nn.MaxPool2d(3, 2))
        self.block1 = nn.Sequential(
            ConvBlock(3, 64, 256, 64, 1),  # stride 1: no extra downsampling right after the max-pool
            IdentityBlock(3, 256, 64),
            IdentityBlock(3, 256, 64))
        self.block2 = nn.Sequential(
            ConvBlock(3, 256, 512, 128),
            IdentityBlock(3, 512, 128),
            IdentityBlock(3, 512, 128),
            IdentityBlock(3, 512, 128))
        self.block3 = nn.Sequential(
            ConvBlock(3, 512, 1024, 256),
            IdentityBlock(3, 1024, 256),
            IdentityBlock(3, 1024, 256),
            IdentityBlock(3, 1024, 256),
            IdentityBlock(3, 1024, 256),
            IdentityBlock(3, 1024, 256))
        self.block4 = nn.Sequential(
            ConvBlock(3, 1024, 2048, 512),
            IdentityBlock(3, 2048, 512),
            IdentityBlock(3, 2048, 512))
        # with a 224x224 input the feature map is already 7x7 here, so this pool leaves it unchanged
        self.avgpool = nn.AdaptiveAvgPool2d(7)
        self.fc = nn.Linear(7 * 7 * 2048, classes)
        # note: nn.CrossEntropyLoss applies log-softmax itself, so with this layer the loss sees softmax twice
        self.softmax = nn.Softmax(dim=1)

    def forward(self, inputs):
        x = self.conv1(inputs)
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.block4(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)  # flatten to (batch, 7*7*2048)
        x = self.fc(x)
        x = self.softmax(x)
        return x
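The feature-map sizes in the summary below (55, 28, 14, 7) follow from the output-size formula given in the PyTorch Conv2d and MaxPool2d documentation, floor((H + 2p - d(k - 1) - 1) / s + 1). Here is a small pure-Python check, assuming a 224×224 input. The 'same'-padded 3×3 convolutions keep the size unchanged, so only the stem and the strided 1×1 convolutions affect it.

```python
import math


def conv_out(size, kernel, stride=1, padding=0, dilation=1):
    """Output length of Conv2d / MaxPool2d along one spatial dimension (floor rounding)."""
    return math.floor((size + 2 * padding - dilation * (kernel - 1) - 1) / stride + 1)


size = 224
size = conv_out(size, kernel=7, stride=2, padding=3)  # stem conv: 224 -> 112
size = conv_out(size, kernel=3, stride=2)             # MaxPool2d(3, 2), no padding: 112 -> 55
stages = [size]                                       # block1 uses stride 1, so it stays 55
for _ in range(3):                                    # block2..block4: strided 1x1 conv
    size = conv_out(size, kernel=1, stride=2)
    stages.append(size)
print(stages)  # [55, 28, 14, 7]
```

Because this max-pool has no padding, the first stage is 55×55. For comparison, torchvision's resnet50 pads its max-pool (padding=1) and gets 56×56. The check also shows that block4 already outputs 7×7, so AdaptiveAvgPool2d(7) changes nothing and the classifier receives 7*7*2048 features.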
Use the torchinfo library to print the model's layer structure and parameter counts:
model = ResNet50(3, len(class_names)).to(device)
summary(model, input_size=(8, 3, 224, 224))
The output is as follows:
==========================================================================================
Layer (type:depth-idx) Output Shape Param #
==========================================================================================
ResNet50 [8, 4] --
├─Sequential: 1-1 [8, 64, 55, 55] --
│ └─Conv2d: 2-1 [8, 64, 112, 112] 9,472
│ └─BatchNorm2d: 2-2 [8, 64, 112, 112] 128
│ └─ReLU: 2-3 [8, 64, 112, 112] --
│ └─MaxPool2d: 2-4 [8, 64, 55, 55] --
├─Sequential: 1-2 [8, 256, 55, 55] --
│ └─ConvBlock: 2-5 [8, 256, 55, 55] --
│ │ └─Sequential: 3-1 [8, 64, 55, 55] 4,288
│ │ └─Sequential: 3-2 [8, 64, 55, 55] 37,056
│ │ └─Sequential: 3-3 [8, 256, 55, 55] 17,152
│ │ └─Sequential: 3-4 [8, 256, 55, 55] 17,152
│ │ └─ReLU: 3-5 [8, 256, 55, 55] --
│ └─IdentityBlock: 2-6 [8, 256, 55, 55] --
│ │ └─Sequential: 3-6 [8, 64, 55, 55] 16,576
│ │ └─Sequential: 3-7 [8, 64, 55, 55] 37,056
│ │ └─Sequential: 3-8 [8, 256, 55, 55] 17,152
│ │ └─ReLU: 3-9 [8, 256, 55, 55] --
│ └─IdentityBlock: 2-7 [8, 256, 55, 55] --
│ │ └─Sequential: 3-10 [8, 64, 55, 55] 16,576
│ │ └─Sequential: 3-11 [8, 64, 55, 55] 37,056
│ │ └─Sequential: 3-12 [8, 256, 55, 55] 17,152
│ │ └─ReLU: 3-13 [8, 256, 55, 55] --
├─Sequential: 1-3 [8, 512, 28, 28] --
│ └─ConvBlock: 2-8 [8, 512, 28, 28] --
│ │ └─Sequential: 3-14 [8, 128, 28, 28] 33,152
│ │ └─Sequential: 3-15 [8, 128, 28, 28] 147,840
│ │ └─Sequential: 3-16 [8, 512, 28, 28] 67,072
│ │ └─Sequential: 3-17 [8, 512, 28, 28] 132,608
│ │ └─ReLU: 3-18 [8, 512, 28, 28] --
│ └─IdentityBlock: 2-9 [8, 512, 28, 28] --
│ │ └─Sequential: 3-19 [8, 128, 28, 28] 65,920
│ │ └─Sequential: 3-20 [8, 128, 28, 28] 147,840
│ │ └─Sequential: 3-21 [8, 512, 28, 28] 67,072
│ │ └─ReLU: 3-22 [8, 512, 28, 28] --
│ └─IdentityBlock: 2-10 [8, 512, 28, 28] --
│ │ └─Sequential: 3-23 [8, 128, 28, 28] 65,920
│ │ └─Sequential: 3-24 [8, 128, 28, 28] 147,840
│ │ └─Sequential: 3-25 [8, 512, 28, 28] 67,072
│ │ └─ReLU: 3-26 [8, 512, 28, 28] --
│ └─IdentityBlock: 2-11 [8, 512, 28, 28] --
│ │ └─Sequential: 3-27 [8, 128, 28, 28] 65,920
│ │ └─Sequential: 3-28 [8, 128, 28, 28] 147,840
│ │ └─Sequential: 3-29 [8, 512, 28, 28] 67,072
│ │ └─ReLU: 3-30 [8, 512, 28, 28] --
├─Sequential: 1-4 [8, 1024, 14, 14] --
│ └─ConvBlock: 2-12 [8, 1024, 14, 14] --
│ │ └─Sequential: 3-31 [8, 256, 14, 14] 131,840
│ │ └─Sequential: 3-32 [8, 256, 14, 14] 590,592
│ │ └─Sequential: 3-33 [8, 1024, 14, 14] 265,216
│ │ └─Sequential: 3-34 [8, 1024, 14, 14] 527,360
│ │ └─ReLU: 3-35 [8, 1024, 14, 14] --
│ └─IdentityBlock: 2-13 [8, 1024, 14, 14] --
│ │ └─Sequential: 3-36 [8, 256, 14, 14] 262,912
│ │ └─Sequential: 3-37 [8, 256, 14, 14] 590,592
│ │ └─Sequential: 3-38 [8, 1024, 14, 14] 265,216
│ │ └─ReLU: 3-39 [8, 1024, 14, 14] --
│ └─IdentityBlock: 2-14 [8, 1024, 14, 14] --
│ │ └─Sequential: 3-40 [8, 256, 14, 14] 262,912
│ │ └─Sequential: 3-41 [8, 256, 14, 14] 590,592
│ │ └─Sequential: 3-42 [8, 1024, 14, 14] 265,216
│ │ └─ReLU: 3-43 [8, 1024, 14, 14] --
│ └─IdentityBlock: 2-15 [8, 1024, 14, 14] --
│ │ └─Sequential: 3-44 [8, 256, 14, 14] 262,912
│ │ └─Sequential: 3-45 [8, 256, 14, 14] 590,592
│ │ └─Sequential: 3-46 [8, 1024, 14, 14] 265,216
│ │ └─ReLU: 3-47 [8, 1024, 14, 14] --
│ └─IdentityBlock: 2-16 [8, 1024, 14, 14] --
│ │ └─Sequential: 3-48 [8, 256, 14, 14] 262,912
│ │ └─Sequential: 3-49 [8, 256, 14, 14] 590,592
│ │ └─Sequential: 3-50 [8, 1024, 14, 14] 265,216
│ │ └─ReLU: 3-51 [8, 1024, 14, 14] --
│ └─IdentityBlock: 2-17 [8, 1024, 14, 14] --
│ │ └─Sequential: 3-52 [8, 256, 14, 14] 262,912
│ │ └─Sequential: 3-53 [8, 256, 14, 14] 590,592
│ │ └─Sequential: 3-54 [8, 1024, 14, 14] 265,216
│ │ └─ReLU: 3-55 [8, 1024, 14, 14] --
├─Sequential: 1-5 [8, 2048, 7, 7] --
│ └─ConvBlock: 2-18 [8, 2048, 7, 7] --
│ │ └─Sequential: 3-56 [8, 512, 7, 7] 525,824
│ │ └─Sequential: 3-57 [8, 512, 7, 7] 2,360,832
│ │ └─Sequential: 3-58 [8, 2048, 7, 7] 1,054,720
│ │ └─Sequential: 3-59 [8, 2048, 7, 7] 2,103,296
│ │ └─ReLU: 3-60 [8, 2048, 7, 7] --
│ └─IdentityBlock: 2-19 [8, 2048, 7, 7] --
│ │ └─Sequential: 3-61 [8, 512, 7, 7] 1,050,112
│ │ └─Sequential: 3-62 [8, 512, 7, 7] 2,360,832
│ │ └─Sequential: 3-63 [8, 2048, 7, 7] 1,054,720
│ │ └─ReLU: 3-64 [8, 2048, 7, 7] --
│ └─IdentityBlock: 2-20 [8, 2048, 7, 7] --
│ │ └─Sequential: 3-65 [8, 512, 7, 7] 1,050,112
│ │ └─Sequential: 3-66 [8, 512, 7, 7] 2,360,832
│ │ └─Sequential: 3-67 [8, 2048, 7, 7] 1,054,720
│ │ └─ReLU: 3-68 [8, 2048, 7, 7] --
├─AdaptiveAvgPool2d: 1-6 [8, 2048, 7, 7] --
├─Linear: 1-7 [8, 4] 401,412
├─Softmax: 1-8 [8, 4] --
==========================================================================================
Total params: 23,936,004
Trainable params: 23,936,004
Non-trainable params: 0
Total mult-adds (G): 30.75
==========================================================================================
Input size (MB): 4.82
Forward/backward pass size (MB): 1335.15
Params size (MB): 95.74
Estimated Total Size (MB): 1435.71
==========================================================================================
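You can reproduce the parameter counts in the table by hand. A Conv2d with bias has c_in * c_out * k * k + c_out parameters. BatchNorm2d adds one learnable scale and one shift per channel (2 * c_out). Its running mean and variance are buffers, not parameters, so torchinfo does not count them. The sketch below checks a few rows, using conv_bn_params as a hypothetical helper name.

```python
def conv_bn_params(c_in, c_out, k):
    """Parameters of Conv2d(c_in, c_out, k) with bias, followed by BatchNorm2d(c_out)."""
    conv = c_in * c_out * k * k + c_out  # weights + biases
    bn = 2 * c_out                       # learnable weight (scale) and bias (shift)
    return conv + bn


print(3 * 64 * 7 * 7 + 64)          # 9472: stem Conv2d (row 2-1)
# First IdentityBlock(3, 256, 64): rows 3-6, 3-7, 3-8
print(conv_bn_params(256, 64, 1))   # 16576
print(conv_bn_params(64, 64, 3))    # 37056
print(conv_bn_params(64, 256, 1))   # 17152
# ConvBlock of block2: strided 1x1 conv (row 3-14) and projection shortcut (row 3-17)
print(conv_bn_params(256, 128, 1))  # 33152
print(conv_bn_params(256, 512, 1))  # 132608
# Final Linear: 7*7*2048 inputs -> 4 classes, plus 4 biases (row 1-7)
print(7 * 7 * 2048 * 4 + 4)         # 401412
```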
Model Training
Write the training function:
def train(train_loader, model, loss_fn, optimizer):
    size = len(train_loader.dataset)  # number of samples
    num_batches = len(train_loader)
    train_loss, train_acc = 0, 0
    for x, y in train_loader:
        x, y = x.to(device), y.to(device)
        pred = model(x)
        loss = loss_fn(pred, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        train_acc += (pred.argmax(1) == y).type(torch.float).sum().item()
    train_loss /= num_batches  # mean loss per batch
    train_acc /= size          # fraction of correctly classified samples
    return train_loss, train_acc
Write the test function:
def test(test_loader, model, loss_fn):
    # called under torch.no_grad() in the training loop, so no gradients are tracked
    size = len(test_loader.dataset)
    num_batches = len(test_loader)
    test_loss, test_acc = 0, 0
    for x, y in test_loader:
        x, y = x.to(device), y.to(device)
        pred = model(x)
        loss = loss_fn(pred, y)
        test_loss += loss.item()
        test_acc += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    test_acc /= size
    return test_loss, test_acc
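The training loop below only records metrics. A common addition is to keep a copy of the weights with the best test accuracy, for example via copy.deepcopy (the copy module is already imported above). Here is a minimal sketch with a hypothetical BestModelTracker helper. A small nn.Linear stands in for ResNet50.

```python
import copy

import torch
import torch.nn as nn


class BestModelTracker:
    """Hypothetical helper: remember the weights with the highest test accuracy."""

    def __init__(self):
        self.best_acc = float("-inf")
        self.best_epoch = None
        self.best_state = None

    def update(self, model, acc, epoch):
        """Store a deep copy of model.state_dict() if `acc` beats the best so far.
        Returns True when a new best was recorded."""
        if acc <= self.best_acc:
            return False
        self.best_acc = acc
        self.best_epoch = epoch
        # deepcopy so later optimizer steps cannot modify the saved tensors in place
        self.best_state = copy.deepcopy(model.state_dict())
        return True


toy_model = nn.Linear(2, 1)
tracker = BestModelTracker()
tracker.update(toy_model, acc=0.70, epoch=0)
with torch.no_grad():
    toy_model.weight.fill_(123.0)  # simulate further training changing the weights
tracker.update(toy_model, acc=0.65, epoch=1)  # worse, so nothing is stored
print(tracker.best_epoch)  # 0
```

Inside the real loop, this would be called as tracker.update(model, epoch_test_acc, epoch). After training, model.load_state_dict(tracker.best_state) restores the best weights.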
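Write the training loop. ResNet-50 is deep, and the TensorFlow weights provided by K同学啊 cannot be loaded directly into this PyTorch model. Residual connections mitigate the degradation problem of deep networks, so I trained from scratch for 200 epochs. Training uses the Adam optimizer with learning-rate decay and a small initial learning rate of 1e-7.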
epochs = 200
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-7)
# multiply the learning rate by 0.97 every 10 epochs
scheduler = optim.lr_scheduler.LambdaLR(optimizer=optimizer, lr_lambda=lambda epoch: 0.97 ** (epoch // 10))
train_loss, train_acc = [], []
test_loss, test_acc = [], []
for epoch in range(epochs):
    model.train()
    epoch_train_loss, epoch_train_acc = train(train_loader, model, loss_fn, optimizer)
    scheduler.step()
    model.eval()
    with torch.no_grad():
        epoch_test_loss, epoch_test_acc = test(test_loader, model, loss_fn)
    train_loss.append(epoch_train_loss)
    train_acc.append(epoch_train_acc)
    test_loss.append(epoch_test_loss)
    test_acc.append(epoch_test_acc)
    lr = optimizer.param_groups[0]['lr']  # current learning rate after scheduler.step()
    print(f"Epoch: {epoch + 1}, lr: {lr}, TrainAcc: {epoch_train_acc*100:.1f}, TrainLoss: {epoch_train_loss:.3f}, TestAcc: {epoch_test_acc*100:.1f}, TestLoss: {epoch_test_loss:.3f}")
print("done")
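The LambdaLR factor gamma ** (epoch // step_size) is a step decay, the same schedule as optim.lr_scheduler.StepLR(optimizer, step_size, gamma). The hypothetical helper step_decay_lr below computes it directly, which makes it easy to compare against the Lr column of the log. The loop prints the rate after scheduler.step(), so the line for "Epoch: k" shows the rate after k steps.

```python
def step_decay_lr(base_lr, gamma, step_size, steps_taken):
    """Learning rate after `steps_taken` calls to scheduler.step() for
    LambdaLR(lr_lambda=lambda e: gamma ** (e // step_size)),
    i.e. the same step-decay schedule as StepLR(step_size, gamma)."""
    return base_lr * gamma ** (steps_taken // step_size)


# Settings written in the code above: 1e-7, multiplied by 0.97 every 10 epochs
print(step_decay_lr(1e-7, 0.97, 10, 200))

# Settings that reproduce the Lr column in the log: 1e-5, multiplied by 0.95 every 5 epochs
for k in (1, 4, 5, 15, 187):
    print(f"Epoch: {k}, Lr: {step_decay_lr(1e-5, 0.95, 5, k)}")
```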
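After 187 epochs, accuracy on the training set had reached 100%, but the best test accuracy was only about 80% (80.5%). The Lr column shows that this run started at 1e-5 and decayed by a factor of 0.95 every 5 epochs. These settings differ from the 1e-7 and 0.97-every-10-epochs values in the code above.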
Epoch: 1, Lr:1e-05, TrainAcc: 37.2, TrainLoss: 1.342, TestAcc: 32.7, TestLoss: 1.332
Epoch: 2, Lr:1e-05, TrainAcc: 52.4, TrainLoss: 1.225, TestAcc: 43.4, TestLoss: 1.271
Epoch: 3, Lr:1e-05, TrainAcc: 63.5, TrainLoss: 1.143, TestAcc: 52.2, TestLoss: 1.220
Epoch: 4, Lr:1e-05, TrainAcc: 64.8, TrainLoss: 1.126, TestAcc: 60.2, TestLoss: 1.191
Epoch: 5, Lr:9.5e-06, TrainAcc: 68.4, TrainLoss: 1.078, TestAcc: 54.0, TestLoss: 1.184
Epoch: 6, Lr:9.5e-06, TrainAcc: 72.6, TrainLoss: 1.040, TestAcc: 47.8, TestLoss: 1.270
Epoch: 7, Lr:9.5e-06, TrainAcc: 71.9, TrainLoss: 1.048, TestAcc: 57.5, TestLoss: 1.139
Epoch: 8, Lr:9.5e-06, TrainAcc: 79.4, TrainLoss: 0.988, TestAcc: 63.7, TestLoss: 1.117
Epoch: 9, Lr:9.5e-06, TrainAcc: 79.4, TrainLoss: 0.972, TestAcc: 64.6, TestLoss: 1.102
Epoch: 10, Lr:9.025e-06, TrainAcc: 81.4, TrainLoss: 0.958, TestAcc: 65.5, TestLoss: 1.095
Epoch: 11, Lr:9.025e-06, TrainAcc: 81.2, TrainLoss: 0.952, TestAcc: 65.5, TestLoss: 1.088
Epoch: 12, Lr:9.025e-06, TrainAcc: 82.7, TrainLoss: 0.931, TestAcc: 66.4, TestLoss: 1.096
Epoch: 13, Lr:9.025e-06, TrainAcc: 84.5, TrainLoss: 0.919, TestAcc: 69.9, TestLoss: 1.069
Epoch: 14, Lr:9.025e-06, TrainAcc: 84.7, TrainLoss: 0.913, TestAcc: 61.9, TestLoss: 1.096
Epoch: 15, Lr:8.573749999999999e-06, TrainAcc: 86.1, TrainLoss: 0.896, TestAcc: 67.3, TestLoss: 1.072
Epoch: 16, Lr:8.573749999999999e-06, TrainAcc: 86.3, TrainLoss: 0.899, TestAcc: 67.3, TestLoss: 1.059
Epoch: 17, Lr:8.573749999999999e-06, TrainAcc: 88.1, TrainLoss: 0.886, TestAcc: 64.6, TestLoss: 1.076
Epoch: 18, Lr:8.573749999999999e-06, TrainAcc: 88.5, TrainLoss: 0.873, TestAcc: 68.1, TestLoss: 1.059
Epoch: 19, Lr:8.573749999999999e-06, TrainAcc: 89.2, TrainLoss: 0.882, TestAcc: 67.3, TestLoss: 1.062
Epoch: 20, Lr:8.1450625e-06, TrainAcc: 89.2, TrainLoss: 0.872, TestAcc: 69.9, TestLoss: 1.034
Epoch: 21, Lr:8.1450625e-06, TrainAcc: 88.9, TrainLoss: 0.862, TestAcc: 70.8, TestLoss: 1.037
Epoch: 22, Lr:8.1450625e-06, TrainAcc: 90.3, TrainLoss: 0.857, TestAcc: 70.8, TestLoss: 1.029
Epoch: 23, Lr:8.1450625e-06, TrainAcc: 92.5, TrainLoss: 0.844, TestAcc: 70.8, TestLoss: 1.016
Epoch: 24, Lr:8.1450625e-06, TrainAcc: 90.9, TrainLoss: 0.846, TestAcc: 68.1, TestLoss: 1.036
Epoch: 25, Lr:7.737809374999999e-06, TrainAcc: 92.3, TrainLoss: 0.841, TestAcc: 72.6, TestLoss: 1.012
Epoch: 26, Lr:7.737809374999999e-06, TrainAcc: 92.0, TrainLoss: 0.836, TestAcc: 69.9, TestLoss: 1.015
Epoch: 27, Lr:7.737809374999999e-06, TrainAcc: 92.7, TrainLoss: 0.833, TestAcc: 74.3, TestLoss: 1.000
Epoch: 28, Lr:7.737809374999999e-06, TrainAcc: 92.3, TrainLoss: 0.837, TestAcc: 76.1, TestLoss: 0.987
Epoch: 29, Lr:7.737809374999999e-06, TrainAcc: 93.8, TrainLoss: 0.817, TestAcc: 71.7, TestLoss: 1.009
Epoch: 30, Lr:7.350918906249998e-06, TrainAcc: 95.1, TrainLoss: 0.807, TestAcc: 71.7, TestLoss: 0.999
Epoch: 31, Lr:7.350918906249998e-06, TrainAcc: 95.4, TrainLoss: 0.806, TestAcc: 77.0, TestLoss: 0.987
Epoch: 32, Lr:7.350918906249998e-06, TrainAcc: 95.1, TrainLoss: 0.804, TestAcc: 74.3, TestLoss: 0.996
Epoch: 33, Lr:7.350918906249998e-06, TrainAcc: 95.1, TrainLoss: 0.801, TestAcc: 71.7, TestLoss: 1.009
Epoch: 34, Lr:7.350918906249998e-06, TrainAcc: 95.4, TrainLoss: 0.801, TestAcc: 76.1, TestLoss: 0.987
Epoch: 35, Lr:6.983372960937498e-06, TrainAcc: 96.9, TrainLoss: 0.790, TestAcc: 73.5, TestLoss: 0.985
Epoch: 36, Lr:6.983372960937498e-06, TrainAcc: 96.5, TrainLoss: 0.797, TestAcc: 75.2, TestLoss: 0.999
Epoch: 37, Lr:6.983372960937498e-06, TrainAcc: 96.9, TrainLoss: 0.789, TestAcc: 77.9, TestLoss: 0.980
Epoch: 38, Lr:6.983372960937498e-06, TrainAcc: 95.6, TrainLoss: 0.795, TestAcc: 74.3, TestLoss: 0.991
Epoch: 39, Lr:6.983372960937498e-06, TrainAcc: 97.1, TrainLoss: 0.781, TestAcc: 77.9, TestLoss: 0.974
Epoch: 40, Lr:6.634204312890623e-06, TrainAcc: 96.7, TrainLoss: 0.784, TestAcc: 75.2, TestLoss: 0.990
Epoch: 41, Lr:6.634204312890623e-06, TrainAcc: 97.3, TrainLoss: 0.782, TestAcc: 74.3, TestLoss: 0.996
Epoch: 42, Lr:6.634204312890623e-06, TrainAcc: 98.0, TrainLoss: 0.777, TestAcc: 76.1, TestLoss: 0.989
Epoch: 43, Lr:6.634204312890623e-06, TrainAcc: 97.3, TrainLoss: 0.781, TestAcc: 73.5, TestLoss: 0.996
Epoch: 44, Lr:6.634204312890623e-06, TrainAcc: 97.6, TrainLoss: 0.778, TestAcc: 78.8, TestLoss: 0.970
Epoch: 45, Lr:6.302494097246091e-06, TrainAcc: 97.8, TrainLoss: 0.777, TestAcc: 78.8, TestLoss: 0.963
Epoch: 46, Lr:6.302494097246091e-06, TrainAcc: 97.6, TrainLoss: 0.777, TestAcc: 77.0, TestLoss: 0.979
Epoch: 47, Lr:6.302494097246091e-06, TrainAcc: 98.2, TrainLoss: 0.772, TestAcc: 75.2, TestLoss: 0.983
Epoch: 48, Lr:6.302494097246091e-06, TrainAcc: 97.6, TrainLoss: 0.774, TestAcc: 75.2, TestLoss: 0.978
Epoch: 49, Lr:6.302494097246091e-06, TrainAcc: 98.0, TrainLoss: 0.773, TestAcc: 75.2, TestLoss: 0.968
Epoch: 50, Lr:5.987369392383788e-06, TrainAcc: 98.0, TrainLoss: 0.771, TestAcc: 76.1, TestLoss: 0.975
Epoch: 51, Lr:5.987369392383788e-06, TrainAcc: 97.8, TrainLoss: 0.773, TestAcc: 77.0, TestLoss: 0.972
Epoch: 52, Lr:5.987369392383788e-06, TrainAcc: 98.0, TrainLoss: 0.769, TestAcc: 75.2, TestLoss: 0.984
Epoch: 53, Lr:5.987369392383788e-06, TrainAcc: 98.0, TrainLoss: 0.770, TestAcc: 73.5, TestLoss: 0.992
Epoch: 54, Lr:5.987369392383788e-06, TrainAcc: 98.0, TrainLoss: 0.769, TestAcc: 77.0, TestLoss: 0.974
Epoch: 55, Lr:5.688000922764597e-06, TrainAcc: 97.6, TrainLoss: 0.776, TestAcc: 76.1, TestLoss: 0.994
Epoch: 56, Lr:5.688000922764597e-06, TrainAcc: 98.2, TrainLoss: 0.767, TestAcc: 73.5, TestLoss: 0.999
Epoch: 57, Lr:5.688000922764597e-06, TrainAcc: 98.7, TrainLoss: 0.764, TestAcc: 73.5, TestLoss: 0.993
Epoch: 58, Lr:5.688000922764597e-06, TrainAcc: 98.5, TrainLoss: 0.761, TestAcc: 77.9, TestLoss: 0.977
Epoch: 59, Lr:5.688000922764597e-06, TrainAcc: 98.5, TrainLoss: 0.764, TestAcc: 73.5, TestLoss: 0.993
Epoch: 60, Lr:5.403600876626367e-06, TrainAcc: 98.5, TrainLoss: 0.766, TestAcc: 77.0, TestLoss: 0.973
Epoch: 61, Lr:5.403600876626367e-06, TrainAcc: 98.2, TrainLoss: 0.765, TestAcc: 75.2, TestLoss: 0.989
Epoch: 62, Lr:5.403600876626367e-06, TrainAcc: 98.7, TrainLoss: 0.760, TestAcc: 75.2, TestLoss: 0.978
Epoch: 63, Lr:5.403600876626367e-06, TrainAcc: 99.1, TrainLoss: 0.762, TestAcc: 74.3, TestLoss: 0.986
Epoch: 64, Lr:5.403600876626367e-06, TrainAcc: 99.1, TrainLoss: 0.757, TestAcc: 76.1, TestLoss: 0.986
Epoch: 65, Lr:5.133420832795049e-06, TrainAcc: 99.3, TrainLoss: 0.757, TestAcc: 78.8, TestLoss: 0.955
Epoch: 66, Lr:5.133420832795049e-06, TrainAcc: 99.3, TrainLoss: 0.755, TestAcc: 77.9, TestLoss: 0.969
Epoch: 67, Lr:5.133420832795049e-06, TrainAcc: 98.5, TrainLoss: 0.759, TestAcc: 77.0, TestLoss: 0.974
Epoch: 68, Lr:5.133420832795049e-06, TrainAcc: 99.3, TrainLoss: 0.754, TestAcc: 77.0, TestLoss: 0.974
Epoch: 69, Lr:5.133420832795049e-06, TrainAcc: 99.1, TrainLoss: 0.755, TestAcc: 75.2, TestLoss: 0.982
Epoch: 70, Lr:4.876749791155296e-06, TrainAcc: 99.3, TrainLoss: 0.756, TestAcc: 74.3, TestLoss: 0.986
Epoch: 71, Lr:4.876749791155296e-06, TrainAcc: 99.3, TrainLoss: 0.755, TestAcc: 77.9, TestLoss: 0.961
Epoch: 72, Lr:4.876749791155296e-06, TrainAcc: 99.3, TrainLoss: 0.755, TestAcc: 80.5, TestLoss: 0.953
Epoch: 73, Lr:4.876749791155296e-06, TrainAcc: 99.3, TrainLoss: 0.754, TestAcc: 77.0, TestLoss: 0.960
Epoch: 74, Lr:4.876749791155296e-06, TrainAcc: 99.3, TrainLoss: 0.755, TestAcc: 76.1, TestLoss: 0.966
Epoch: 75, Lr:4.632912301597531e-06, TrainAcc: 99.3, TrainLoss: 0.755, TestAcc: 73.5, TestLoss: 0.975
Epoch: 76, Lr:4.632912301597531e-06, TrainAcc: 98.9, TrainLoss: 0.758, TestAcc: 77.0, TestLoss: 0.972
Epoch: 77, Lr:4.632912301597531e-06, TrainAcc: 99.3, TrainLoss: 0.753, TestAcc: 75.2, TestLoss: 0.981
Epoch: 78, Lr:4.632912301597531e-06, TrainAcc: 99.6, TrainLoss: 0.753, TestAcc: 75.2, TestLoss: 0.982
Epoch: 79, Lr:4.632912301597531e-06, TrainAcc: 99.6, TrainLoss: 0.751, TestAcc: 77.9, TestLoss: 0.955
Epoch: 80, Lr:4.401266686517654e-06, TrainAcc: 99.3, TrainLoss: 0.753, TestAcc: 75.2, TestLoss: 0.978
Epoch: 81, Lr:4.401266686517654e-06, TrainAcc: 98.9, TrainLoss: 0.760, TestAcc: 76.1, TestLoss: 0.975
Epoch: 82, Lr:4.401266686517654e-06, TrainAcc: 99.1, TrainLoss: 0.755, TestAcc: 77.0, TestLoss: 0.970
Epoch: 83, Lr:4.401266686517654e-06, TrainAcc: 99.8, TrainLoss: 0.750, TestAcc: 76.1, TestLoss: 0.970
Epoch: 84, Lr:4.401266686517654e-06, TrainAcc: 99.8, TrainLoss: 0.749, TestAcc: 75.2, TestLoss: 0.967
Epoch: 85, Lr:4.181203352191771e-06, TrainAcc: 99.8, TrainLoss: 0.750, TestAcc: 75.2, TestLoss: 0.972
Epoch: 86, Lr:4.181203352191771e-06, TrainAcc: 99.8, TrainLoss: 0.749, TestAcc: 76.1, TestLoss: 0.955
Epoch: 87, Lr:4.181203352191771e-06, TrainAcc: 99.6, TrainLoss: 0.753, TestAcc: 77.9, TestLoss: 0.964
Epoch: 88, Lr:4.181203352191771e-06, TrainAcc: 99.8, TrainLoss: 0.748, TestAcc: 77.0, TestLoss: 0.961
Epoch: 89, Lr:4.181203352191771e-06, TrainAcc: 99.6, TrainLoss: 0.752, TestAcc: 77.0, TestLoss: 0.971
Epoch: 90, Lr:3.972143184582182e-06, TrainAcc: 99.8, TrainLoss: 0.751, TestAcc: 76.1, TestLoss: 0.976
Epoch: 91, Lr:3.972143184582182e-06, TrainAcc: 99.8, TrainLoss: 0.748, TestAcc: 75.2, TestLoss: 0.979
Epoch: 92, Lr:3.972143184582182e-06, TrainAcc: 99.8, TrainLoss: 0.748, TestAcc: 75.2, TestLoss: 0.966
Epoch: 93, Lr:3.972143184582182e-06, TrainAcc: 99.8, TrainLoss: 0.749, TestAcc: 78.8, TestLoss: 0.951
Epoch: 94, Lr:3.972143184582182e-06, TrainAcc: 99.6, TrainLoss: 0.750, TestAcc: 74.3, TestLoss: 0.978
Epoch: 95, Lr:3.7735360253530726e-06, TrainAcc: 99.8, TrainLoss: 0.748, TestAcc: 77.0, TestLoss: 0.958
Epoch: 96, Lr:3.7735360253530726e-06, TrainAcc: 99.8, TrainLoss: 0.749, TestAcc: 73.5, TestLoss: 0.981
Epoch: 97, Lr:3.7735360253530726e-06, TrainAcc: 99.8, TrainLoss: 0.748, TestAcc: 74.3, TestLoss: 0.981
Epoch: 98, Lr:3.7735360253530726e-06, TrainAcc: 99.8, TrainLoss: 0.748, TestAcc: 74.3, TestLoss: 0.977
Epoch: 99, Lr:3.7735360253530726e-06, TrainAcc: 99.8, TrainLoss: 0.749, TestAcc: 75.2, TestLoss: 0.972
Epoch: 100, Lr:3.584859224085419e-06, TrainAcc: 99.8, TrainLoss: 0.748, TestAcc: 77.0, TestLoss: 0.966
Epoch: 101, Lr:3.584859224085419e-06, TrainAcc: 99.8, TrainLoss: 0.748, TestAcc: 72.6, TestLoss: 0.976
Epoch: 102, Lr:3.584859224085419e-06, TrainAcc: 99.8, TrainLoss: 0.747, TestAcc: 74.3, TestLoss: 0.975
Epoch: 103, Lr:3.584859224085419e-06, TrainAcc: 99.8, TrainLoss: 0.747, TestAcc: 76.1, TestLoss: 0.960
Epoch: 104, Lr:3.584859224085419e-06, TrainAcc: 99.8, TrainLoss: 0.750, TestAcc: 73.5, TestLoss: 0.979
Epoch: 105, Lr:3.4056162628811484e-06, TrainAcc: 99.8, TrainLoss: 0.747, TestAcc: 75.2, TestLoss: 0.971
Epoch: 106, Lr:3.4056162628811484e-06, TrainAcc: 99.8, TrainLoss: 0.747, TestAcc: 76.1, TestLoss: 0.969
Epoch: 107, Lr:3.4056162628811484e-06, TrainAcc: 99.8, TrainLoss: 0.749, TestAcc: 75.2, TestLoss: 0.973
Epoch: 108, Lr:3.4056162628811484e-06, TrainAcc: 99.8, TrainLoss: 0.748, TestAcc: 75.2, TestLoss: 0.959
Epoch: 109, Lr:3.4056162628811484e-06, TrainAcc: 99.8, TrainLoss: 0.747, TestAcc: 77.9, TestLoss: 0.952
Epoch: 110, Lr:3.2353354497370905e-06, TrainAcc: 99.8, TrainLoss: 0.748, TestAcc: 76.1, TestLoss: 0.973
Epoch: 111, Lr:3.2353354497370905e-06, TrainAcc: 99.8, TrainLoss: 0.748, TestAcc: 76.1, TestLoss: 0.981
Epoch: 112, Lr:3.2353354497370905e-06, TrainAcc: 99.8, TrainLoss: 0.747, TestAcc: 77.9, TestLoss: 0.954
Epoch: 113, Lr:3.2353354497370905e-06, TrainAcc: 99.8, TrainLoss: 0.747, TestAcc: 74.3, TestLoss: 0.971
Epoch: 114, Lr:3.2353354497370905e-06, TrainAcc: 99.8, TrainLoss: 0.747, TestAcc: 77.0, TestLoss: 0.961
Epoch: 115, Lr:3.073568677250236e-06, TrainAcc: 99.8, TrainLoss: 0.748, TestAcc: 74.3, TestLoss: 0.969
Epoch: 116, Lr:3.073568677250236e-06, TrainAcc: 99.8, TrainLoss: 0.747, TestAcc: 78.8, TestLoss: 0.960
Epoch: 117, Lr:3.073568677250236e-06, TrainAcc: 99.8, TrainLoss: 0.746, TestAcc: 76.1, TestLoss: 0.964
Epoch: 118, Lr:3.073568677250236e-06, TrainAcc: 99.8, TrainLoss: 0.747, TestAcc: 75.2, TestLoss: 0.969
Epoch: 119, Lr:3.073568677250236e-06, TrainAcc: 99.8, TrainLoss: 0.748, TestAcc: 77.0, TestLoss: 0.963
Epoch: 120, Lr:2.919890243387724e-06, TrainAcc: 99.6, TrainLoss: 0.748, TestAcc: 75.2, TestLoss: 0.968
Epoch: 121, Lr:2.919890243387724e-06, TrainAcc: 99.8, TrainLoss: 0.747, TestAcc: 72.6, TestLoss: 0.961
Epoch: 122, Lr:2.919890243387724e-06, TrainAcc: 99.8, TrainLoss: 0.747, TestAcc: 75.2, TestLoss: 0.969
Epoch: 123, Lr:2.919890243387724e-06, TrainAcc: 99.6, TrainLoss: 0.751, TestAcc: 72.6, TestLoss: 0.974
Epoch: 124, Lr:2.919890243387724e-06, TrainAcc: 99.8, TrainLoss: 0.747, TestAcc: 74.3, TestLoss: 0.974
Epoch: 125, Lr:2.7738957312183377e-06, TrainAcc: 99.8, TrainLoss: 0.748, TestAcc: 74.3, TestLoss: 0.975
Epoch: 126, Lr:2.7738957312183377e-06, TrainAcc: 99.8, TrainLoss: 0.748, TestAcc: 74.3, TestLoss: 0.972
Epoch: 127, Lr:2.7738957312183377e-06, TrainAcc: 99.8, TrainLoss: 0.748, TestAcc: 75.2, TestLoss: 0.967
Epoch: 128, Lr:2.7738957312183377e-06, TrainAcc: 99.8, TrainLoss: 0.747, TestAcc: 72.6, TestLoss: 0.980
Epoch: 129, Lr:2.7738957312183377e-06, TrainAcc: 99.6, TrainLoss: 0.749, TestAcc: 74.3, TestLoss: 0.980
Epoch: 130, Lr:2.6352009446574206e-06, TrainAcc: 99.8, TrainLoss: 0.747, TestAcc: 73.5, TestLoss: 0.985
Epoch: 131, Lr:2.6352009446574206e-06, TrainAcc: 99.8, TrainLoss: 0.748, TestAcc: 74.3, TestLoss: 0.980
Epoch: 132, Lr:2.6352009446574206e-06, TrainAcc: 99.8, TrainLoss: 0.747, TestAcc: 76.1, TestLoss: 0.967
Epoch: 133, Lr:2.6352009446574206e-06, TrainAcc: 99.8, TrainLoss: 0.746, TestAcc: 75.2, TestLoss: 0.982
Epoch: 134, Lr:2.6352009446574206e-06, TrainAcc: 99.8, TrainLoss: 0.746, TestAcc: 76.1, TestLoss: 0.982
Epoch: 135, Lr:2.5034408974245495e-06, TrainAcc: 99.8, TrainLoss: 0.747, TestAcc: 77.0, TestLoss: 0.968
Epoch: 136, Lr:2.5034408974245495e-06, TrainAcc: 99.8, TrainLoss: 0.747, TestAcc: 75.2, TestLoss: 0.970
Epoch: 137, Lr:2.5034408974245495e-06, TrainAcc: 99.8, TrainLoss: 0.748, TestAcc: 77.0, TestLoss: 0.976
Epoch: 138, Lr:2.5034408974245495e-06, TrainAcc: 99.8, TrainLoss: 0.747, TestAcc: 76.1, TestLoss: 0.954
Epoch: 139, Lr:2.5034408974245495e-06, TrainAcc: 99.8, TrainLoss: 0.747, TestAcc: 75.2, TestLoss: 0.965
Epoch: 140, Lr:2.378268852553322e-06, TrainAcc: 99.8, TrainLoss: 0.746, TestAcc: 76.1, TestLoss: 0.959
Epoch: 141, Lr:2.378268852553322e-06, TrainAcc: 99.6, TrainLoss: 0.749, TestAcc: 78.8, TestLoss: 0.950
Epoch: 142, Lr:2.378268852553322e-06, TrainAcc: 99.8, TrainLoss: 0.747, TestAcc: 78.8, TestLoss: 0.943
Epoch: 143, Lr:2.378268852553322e-06, TrainAcc: 99.8, TrainLoss: 0.746, TestAcc: 77.9, TestLoss: 0.948
Epoch: 144, Lr:2.378268852553322e-06, TrainAcc: 99.8, TrainLoss: 0.746, TestAcc: 77.0, TestLoss: 0.965
Epoch: 145, Lr:2.2593554099256557e-06, TrainAcc: 99.6, TrainLoss: 0.749, TestAcc: 80.5, TestLoss: 0.947
Epoch: 146, Lr:2.2593554099256557e-06, TrainAcc: 99.8, TrainLoss: 0.748, TestAcc: 80.5, TestLoss: 0.946
Epoch: 147, Lr:2.2593554099256557e-06, TrainAcc: 99.6, TrainLoss: 0.749, TestAcc: 76.1, TestLoss: 0.953
Epoch: 148, Lr:2.2593554099256557e-06, TrainAcc: 100.0, TrainLoss: 0.745, TestAcc: 75.2, TestLoss: 0.976
Epoch: 149, Lr:2.2593554099256557e-06, TrainAcc: 100.0, TrainLoss: 0.745, TestAcc: 77.0, TestLoss: 0.957
Epoch: 150, Lr:2.146387639429373e-06, TrainAcc: 100.0, TrainLoss: 0.745, TestAcc: 72.6, TestLoss: 0.964
Epoch: 151, Lr:2.146387639429373e-06, TrainAcc: 100.0, TrainLoss: 0.745, TestAcc: 74.3, TestLoss: 0.972
Epoch: 152, Lr:2.146387639429373e-06, TrainAcc: 100.0, TrainLoss: 0.745, TestAcc: 75.2, TestLoss: 0.968
Epoch: 153, Lr:2.146387639429373e-06, TrainAcc: 100.0, TrainLoss: 0.746, TestAcc: 77.9, TestLoss: 0.955
Epoch: 154, Lr:2.146387639429373e-06, TrainAcc: 100.0, TrainLoss: 0.744, TestAcc: 77.0, TestLoss: 0.951
Epoch: 155, Lr:2.039068257457904e-06, TrainAcc: 100.0, TrainLoss: 0.744, TestAcc: 78.8, TestLoss: 0.952
Epoch: 156, Lr:2.039068257457904e-06, TrainAcc: 100.0, TrainLoss: 0.744, TestAcc: 78.8, TestLoss: 0.946
Epoch: 157, Lr:2.039068257457904e-06, TrainAcc: 100.0, TrainLoss: 0.744, TestAcc: 78.8, TestLoss: 0.956
Epoch: 158, Lr:2.039068257457904e-06, TrainAcc: 100.0, TrainLoss: 0.746, TestAcc: 76.1, TestLoss: 0.964
Epoch: 159, Lr:2.039068257457904e-06, TrainAcc: 100.0, TrainLoss: 0.744, TestAcc: 77.0, TestLoss: 0.956
Epoch: 160, Lr:1.937114844585009e-06, TrainAcc: 99.8, TrainLoss: 0.747, TestAcc: 77.0, TestLoss: 0.960
Epoch: 161, Lr:1.937114844585009e-06, TrainAcc: 100.0, TrainLoss: 0.745, TestAcc: 77.0, TestLoss: 0.965
Epoch: 162, Lr:1.937114844585009e-06, TrainAcc: 100.0, TrainLoss: 0.744, TestAcc: 77.0, TestLoss: 0.960
Epoch: 163, Lr:1.937114844585009e-06, TrainAcc: 100.0, TrainLoss: 0.744, TestAcc: 78.8, TestLoss: 0.959
Epoch: 164, Lr:1.937114844585009e-06, TrainAcc: 100.0, TrainLoss: 0.744, TestAcc: 76.1, TestLoss: 0.949
Epoch: 165, Lr:1.8402591023557584e-06, TrainAcc: 100.0, TrainLoss: 0.744, TestAcc: 76.1, TestLoss: 0.953
Epoch: 166, Lr:1.8402591023557584e-06, TrainAcc: 100.0, TrainLoss: 0.744, TestAcc: 80.5, TestLoss: 0.944
Epoch: 167, Lr:1.8402591023557584e-06, TrainAcc: 100.0, TrainLoss: 0.744, TestAcc: 76.1, TestLoss: 0.954
Epoch: 168, Lr:1.8402591023557584e-06, TrainAcc: 100.0, TrainLoss: 0.744, TestAcc: 78.8, TestLoss: 0.951
Epoch: 169, Lr:1.8402591023557584e-06, TrainAcc: 100.0, TrainLoss: 0.744, TestAcc: 77.9, TestLoss: 0.963
Epoch: 170, Lr:1.7482461472379704e-06, TrainAcc: 100.0, TrainLoss: 0.747, TestAcc: 75.2, TestLoss: 0.957
Epoch: 171, Lr:1.7482461472379704e-06, TrainAcc: 100.0, TrainLoss: 0.745, TestAcc: 77.0, TestLoss: 0.953
Epoch: 172, Lr:1.7482461472379704e-06, TrainAcc: 100.0, TrainLoss: 0.744, TestAcc: 75.2, TestLoss: 0.971
Epoch: 173, Lr:1.7482461472379704e-06, TrainAcc: 100.0, TrainLoss: 0.744, TestAcc: 76.1, TestLoss: 0.952
Epoch: 174, Lr:1.7482461472379704e-06, TrainAcc: 100.0, TrainLoss: 0.744, TestAcc: 78.8, TestLoss: 0.944
Epoch: 175, Lr:1.6608338398760719e-06, TrainAcc: 100.0, TrainLoss: 0.744, TestAcc: 78.8, TestLoss: 0.939
Epoch: 176, Lr:1.6608338398760719e-06, TrainAcc: 100.0, TrainLoss: 0.745, TestAcc: 79.6, TestLoss: 0.943
Epoch: 177, Lr:1.6608338398760719e-06, TrainAcc: 100.0, TrainLoss: 0.744, TestAcc: 77.0, TestLoss: 0.963
Epoch: 178, Lr:1.6608338398760719e-06, TrainAcc: 100.0, TrainLoss: 0.744, TestAcc: 77.9, TestLoss: 0.956
Epoch: 179, Lr:1.6608338398760719e-06, TrainAcc: 100.0, TrainLoss: 0.744, TestAcc: 77.0, TestLoss: 0.957
Epoch: 180, Lr:1.577792147882268e-06, TrainAcc: 100.0, TrainLoss: 0.744, TestAcc: 78.8, TestLoss: 0.955
Epoch: 181, Lr:1.577792147882268e-06, TrainAcc: 100.0, TrainLoss: 0.744, TestAcc: 75.2, TestLoss: 0.955
Epoch: 182, Lr:1.577792147882268e-06, TrainAcc: 100.0, TrainLoss: 0.744, TestAcc: 76.1, TestLoss: 0.952
Epoch: 183, Lr:1.577792147882268e-06, TrainAcc: 100.0, TrainLoss: 0.744, TestAcc: 77.0, TestLoss: 0.955
Epoch: 184, Lr:1.577792147882268e-06, TrainAcc: 100.0, TrainLoss: 0.744, TestAcc: 77.9, TestLoss: 0.946
Epoch: 185, Lr:1.4989025404881547e-06, TrainAcc: 100.0, TrainLoss: 0.744, TestAcc: 78.8, TestLoss: 0.947
Epoch: 186, Lr:1.4989025404881547e-06, TrainAcc: 100.0, TrainLoss: 0.744, TestAcc: 77.9, TestLoss: 0.953
Epoch: 187, Lr:1.4989025404881547e-06, TrainAcc: 100.0, TrainLoss: 0.744, TestAcc: 76.1, TestLoss: 0.949
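The training loss never drops below 0.744, even at 100% training accuracy. A likely explanation is the final nn.Softmax in ResNet50. According to the PyTorch documentation, nn.CrossEntropyLoss expects raw, unnormalized logits and applies log-softmax itself. Here, however, the loss receives probabilities, which confines it to a narrow range. The plain-Python sketch below reproduces that range for 4 classes. It re-implements the cross-entropy formula for a single sample and is not the training code.

```python
import math


def cross_entropy(scores, target):
    """Cross-entropy for one sample as nn.CrossEntropyLoss computes it:
    -scores[target] + log(sum(exp(scores))), treating `scores` as logits."""
    return -scores[target] + math.log(sum(math.exp(s) for s in scores))


num_classes = 4
# The model already ends in Softmax, so the "scores" fed to the loss are probabilities.
# Best case: a perfectly confident, correct prediction.
floor = cross_entropy([1.0] + [0.0] * (num_classes - 1), target=0)
# Worst case: all probability mass on a wrong class.
ceiling = cross_entropy([0.0, 1.0] + [0.0] * (num_classes - 2), target=0)
print(f"{floor:.4f} {ceiling:.4f}")  # 0.7437 1.7437
```

The floor ln(e + 3) - 1 ≈ 0.744 matches where TrainLoss settles in the log. The ceiling of about 1.744 also explains why the loss looked small from the very start. The usual fix is to drop self.softmax from forward so that the model returns logits, and to apply torch.softmax(model(x), dim=1) only when probabilities are needed at inference time. That change would require retraining, because every result above was produced with the Softmax in place.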
Model Results
epoch_ranges = range(len(train_loss))  # one point per recorded epoch
plt.figure(figsize=(20,6))
plt.subplot(121)
plt.plot(epoch_ranges, train_loss, label='train loss')
plt.plot(epoch_ranges, test_loss, label='validation loss')
plt.legend(loc='upper right')
plt.title('Loss')
plt.subplot(122)
plt.plot(epoch_ranges, train_acc, label='train accuracy')
plt.plot(epoch_ranges, test_acc, label='validation accuracy')
plt.legend(loc='lower right')
plt.title('Accuracy')
Loss curves
Accuracy curves
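Summary and Takeaways
At first I used a large learning rate of 0.1, and training made no progress at all. I then noticed that the initial loss was already very small, only about 1.5, so I switched to a small learning rate, and only then did the model start to train.
Both IdentityBlock and ConvBlock first shrink the feature map with a 1×1 convolution, do the actual feature extraction at the reduced width, and then use another 1×1 convolution to restore the original number of channels. I am still quite puzzled by this shrink-then-expand pattern. The features have already been computed in the lower-dimensional space, so expanding them again cannot produce new features and would instead add a lot of noise. Is that noise somehow necessary?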
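One way to look at the 1×1 reduce/expand pattern is its cost. In the ResNet paper, the bottleneck design is motivated by computational budget. The 3×3 convolution, the expensive part, runs at 64 channels instead of 256. The final 1×1 convolution then brings the width back to 256 so the result can be added to the 256-channel shortcut. The sketch below gives a rough weight count (biases ignored) and compares it with a hypothetical block of two 3×3 convolutions kept at 256 channels.

```python
def conv_weights(c_in, c_out, k):
    """Number of weights in a k x k convolution (biases ignored)."""
    return c_in * c_out * k * k


# Bottleneck as in IdentityBlock(3, 256, 64): 1x1 reduce -> 3x3 -> 1x1 expand
bottleneck = conv_weights(256, 64, 1) + conv_weights(64, 64, 3) + conv_weights(64, 256, 1)
# Hypothetical alternative without reduction: two 3x3 convolutions at 256 channels
plain = 2 * conv_weights(256, 256, 3)
print(bottleneck, plain, round(plain / bottleneck, 1))  # 69632 1179648 16.9
```

Both variants keep the spatial size unchanged, so the multiply-adds per output position shrink by the same factor of roughly 17.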
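No pretrained weights were available and the dataset is small, so the model overfits the training set (100% training accuracy versus about 80% test accuracy). Enlarging the dataset should help.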