3.2.微调

微调

​ 对于一些样本数量有限的数据集,如果使用较大的模型,可能很快过拟合,较小的模型可能效果不好。这个问题的一个解决方案是收集更多数据,但其实在很多情况下这是很难做到的。

​ 另一种方法就是迁移学习(transfer learning),将源数据集学到地知识迁移到目标数据集,例如,我们只想识别椅子,只有100把椅子,每把椅子的1000张不同角度的图像,尽管ImageNet数据集中大多数图像与椅子无关,但在次数据集上训练的模型可能会提取更通用的图像特征(可以理解为越底层的layer提取的特征越通用),这有助于识别边缘、纹理、形状和对象组合,也可能有效地识别椅子。

在这里插入图片描述

1. 步骤

​ 微调是迁移学习中的常见技巧,步骤如下:

  1. 在源数据集(例如ImageNet数据集)上预训练神经网络模型,即源模型
  2. 创建一个新的神经网络模型,即目标模型。这将复制源模型上的所有模型设计及其参数(输出层除外)。我们假定这些模型参数包含从源数据集中学到的知识,这些知识也将适用于目标数据集。我们还假设源模型的输出层与源数据集的标签密切相关;因此不在目标模型中使用该层。
  3. 向目标模型添加输出层,其输出数是目标数据集中的类别数。然后随机初始化该层的模型参数。
  4. 在目标数据集(如椅子数据集)上训练目标模型。输出层将从头开始进行训练,而所有其他层的参数将根据源模型的参数进行微调。

在这里插入图片描述

1.1 目标模型的训练:

​ 是一个正在目标数据集上的正常训练任务,但使用更强的正则化(参数变化不大):

  • 更小的学习率
  • 更少的数据迭代

​ 如果源数据集远复杂于目标数据,通常微调效果更好

1.2 重用分类器权重

​ 有些时候源数据集可能也有目标数据中的部分标号,比如ImageNet里可能有椅子这一标签,那么可以使用预训练好的模型分类器中对应标号中对应的向量来做初始化(就直接copy)

1.3 固定一些层

​ 通常而言,神经网络中低层次的特征更加通用,高层次的特征则更跟数据集相关。

​ 那么可以固定底部一些层的参数,不参与更新,这样也能有更强的正则。

2.热狗识别

import os
import torch
import torchvision
from torch import nn
from d2l import torch as d2l
import torch_directmldevice = torch_directml.device()
# @save
d2l.DATA_HUB['hotdog'] = (d2l.DATA_URL + 'hotdog.zip','fba480ffa8aa7e0febbb511d181409f899b9baa5')data_dir = d2l.download_extract('hotdog')train_imgs = torchvision.datasets.ImageFolder(os.path.join(data_dir, 'train'))
test_imgs = torchvision.datasets.ImageFolder(os.path.join(data_dir, 'test'))hotdogs = [train_imgs[i][0] for i in range(8)]
not_hotdogs = [train_imgs[-i - 1][0] for i in range(8)]
d2l.show_images(hotdogs + not_hotdogs, 2, 8, scale=1.4)
d2l.plt.show()# 使用RGB通道的均值和标准差,以标准化每个通道 ,因为预训练的模型做了这个
normalize = torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])# 先随机裁剪,并变为224 * 224 的图形,因为预训练模型输入是这个
train_augs = torchvision.transforms.Compose([torchvision.transforms.RandomResizedCrop(224),torchvision.transforms.RandomHorizontalFlip(),torchvision.transforms.ToTensor(),normalize])
# 将图像的高度和宽度都缩放到256像素,然后裁剪中央 224 * 224的区域来作为输入
test_augs = torchvision.transforms.Compose([torchvision.transforms.Resize([256, 256]),torchvision.transforms.CenterCrop(224),torchvision.transforms.ToTensor(),normalize])# 下载模型,pretrained参数已被弃用,使用weights来获取与训练模型
pretrained_net = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1)
print(pretrained_net.fc)  # 预训练最后一层为输出层finetune_net = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1).to(device)
finetune_net.fc = nn.Linear(finetune_net.fc.in_features, 2).to(device)
nn.init.xavier_uniform_(finetune_net.fc.weight)def train_batch_ch13(net, X, y, loss, trainer, devices):"""用多GPU进行小批量训练"""if isinstance(X, list):# 微调BERT中所需X = [x.to(devices) for x in X]else:X = X.to(devices)y = y.to(devices)net.train()trainer.zero_grad()pred = net(X)l = loss(pred, y)l.sum().backward()trainer.step()train_loss_sum = l.sum()train_acc_sum = d2l.accuracy(pred, y)return train_loss_sum, train_acc_sum# @save 多GPU的,把参数devices改成device了,本来是个列表
def train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs,device):"""用多GPU进行模型训练"""timer, num_batches = d2l.Timer(), len(train_iter)animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs], ylim=[0, 1],legend=['train loss', 'train acc', 'test acc'])# net = nn.DataParallel(net, device_ids=devices).to(devices[0])for epoch in range(num_epochs):# 4个维度:储存训练损失,训练准确度,实例数,特点数metric = d2l.Accumulator(4)for i, (features, labels) in enumerate(train_iter):timer.start()l, acc = train_batch_ch13(net, features, labels, loss, trainer, device)metric.add(l, acc, labels.shape[0], labels.numel())timer.stop()if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:animator.add(epoch + (i + 1) / num_batches,(metric[0] / metric[2], metric[1] / metric[3],None))test_acc = d2l.evaluate_accuracy_gpu(net, test_iter)animator.add(epoch + 1, (None, None, test_acc))print(f'loss {metric[0] / metric[2]:.3f}, train acc 'f'{metric[1] / metric[3]:.3f}, test acc {test_acc:.3f}')print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec on 'f'{str(device)}')# 微调
# 如果param_group=True,输出层中的模型参数将使用十倍的学习率
def train_fine_tuning(net, learning_rate, batch_size=128, num_epochs=5,param_group=True):train_iter = torch.utils.data.DataLoader(torchvision.datasets.ImageFolder(os.path.join(data_dir, 'train'), transform=train_augs),batch_size=batch_size, shuffle=True)test_iter = torch.utils.data.DataLoader(torchvision.datasets.ImageFolder(os.path.join(data_dir, 'test'), transform=test_augs),batch_size=batch_size)# devices = d2l.try_all_gpus()device = torch_directml.device()loss = nn.CrossEntropyLoss(reduction="none")if param_group:params_1x = [param for name, param in net.named_parameters()if name not in ["fc.weight", "fc.bias"]]# 最后一层使用10倍学习率trainer = torch.optim.SGD([{'params': params_1x},{'params': net.fc.parameters(),'lr': learning_rate * 10}],lr=learning_rate, weight_decay=0.001)else:trainer = torch.optim.SGD(net.parameters(), lr=learning_rate,weight_decay=0.001)train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs,device)train_fine_tuning(finetune_net, 5e-5)
d2l.plt.show()

loss 0.270, train acc 0.899, test acc 0.948
232.6 examples/sec on privateuseone:0

一次训练效果就很好了,而且后续训练很平滑,没有过拟合。

在这里插入图片描述

​ 如果初始化为随机值:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://xiahunao.cn/news/3280540.html

如若内容造成侵权/违法违规/事实不符,请联系瞎胡闹网进行投诉反馈,一经查实,立即删除!

相关文章

FFplay介绍及命令使用指南

😎 作者介绍:欢迎来到我的主页👈,我是程序员行者孙,一个热爱分享技术的制能工人。计算机本硕,人工制能研究生。公众号:AI Sun(领取大厂面经等资料),欢迎加我的…

微软Win11 24H2最新可选更新补丁26100.1301来袭!

系统之家于7月31日发出最新报道,微软针对Win11 24H2用户推出七月最新的可选更新KB5040529,本次更新为开始菜单引入了全新的账号管理器,也改进了任务栏上的小组件图标。接下来跟随系统之家小编来看看本次更新的详细内容吧!【推荐下…

不同类型游戏安全风险对抗概览(下)| FPS以及小游戏等外挂问题,一文读懂!

FPS 游戏安全问题 由于射击类游戏本身需要大量数值计算,游戏方会将部分计算存放于本地客户端,而这为外挂攻击者提供了攻击的温床。可以说,射击类游戏是所有游戏中被外挂攻击最为频繁的游戏类型。 根据网易易盾游戏安全部门检测数据显示&#…

【排序算法】Java实现三大非比较排序:计数排序、桶排序、基数排序

非比较排序概念 非比较排序是一种排序算法,它不通过比较元素之间的大小关系来进行排序,而是基于元素的特征或属性进行排序。这种方法在特定情况下可以比比较排序方法(如快速排序、归并排序等)更有效率,尤其是在处理大…

【原创】java+ssm+mysql医生信息管理系统设计与实现

个人主页:程序员杨工 个人简介:从事软件开发多年,前后端均有涉猎,具有丰富的开发经验 博客内容:全栈开发,分享Java、Python、Php、小程序、前后端、数据库经验和实战 开发背景: 随着信息技术的…

详解线程的几种状态?

详解线程的几种状态? 1. 新建状态(New)2. 就绪状态(Runnable)3. 运行状态(Running)4. 阻塞状态(Blocked)5. 死亡状态(Dead) 💖The Begin&#x1…

获客工具大揭秘:为何它能让获客如此轻松?

你是不是也觉得,现在的市场环境,获客越来越难了? 今天我要给大家分享一个实用且高效的获客工具,它简直是营销界的福音! 1、关键词搜索 关键词搜索功能是获客工具的基础,也是其重要性不可小觑的原因。 这…

go-zero框架入门---认识微服务以及环境的安装

什么是微服务 微服务是一种软件架构风格,它将一个大型应用程序拆分成多个小型的、独立部署的服务,每个服务实现单一业务功能。每个服务运行在自己的进程中,并通过轻量级的通信机制(通常是HTTP RESTful API)相互协作。…

ubuntu 使用 freeplane

在知乎在过这个问题后 思维导图工具freemind和freeplane的区别? - 知乎。我选择使用 freeplane 作为思维导图的绘制软件。理由不外乎系统受限,和开源软件。 直接在软件商店里搜索 mind ,其实也有其它的软件。第一个也蛮好用的。 安装 如果在…

【分享】HCIP-AI-EI Developer备考攻略

刚考完HCIP-AI-EI Developer就写了这篇热乎的笔记,主要是我在备考的时候发现网上没有相关经验帖,导致备考的时候心态不好。我从自身状态、考试介绍、备考建议、考试技巧等方面进行了总结,非常详细,希望我的这篇笔记能给大家提供一些帮助。 1 我的情况 备考前状态:学过一…

buu做题(11)

[CISCN2019 华东南赛区]Web11 抓个包可以发现是 Smarty框架 在页面可以观察到 一个 XFF头, 可以猜测注入点就在这 通过 if 标签执行命令 ,读取flag if system("cat /flag")}{/if} [极客大挑战 2019]FinalSQL 一个登录框, 上面的提示应该就是要你盲注了 点一下那…

Web : EL表达式 -15

EL表达式概述 EL 全名为Expression Language&#xff0c;用来替代<% %>脚本表达式。 基本结构为${表达式}。 获取数据 获取常量 <h1>获取常量</h1> ${123} ${123.32} ${"abc"} ${true} 获取变量 el会自动从四大作用域中搜寻域属性来使用 如果找不…

vue3后台管理系统 vue3+vite+pinia+element-plus+axios上

前言 项目安装与启动 使用vite作为项目脚手架 # pnpm pnpm create vite my-vue-app --template vue安装相应依赖 # sass pnpm i sass # vue-router pnpm i vue-router # element-plus pnpm i element-plus # element-plus/icon pnpm i element-plus/icons-vue安装element-…

《海军罪案调查处:起源》预告片介绍新角色莱罗伊·杰思罗·吉布斯

《海军罪案调查处&#xff1a;起源》的主演奥斯汀斯托威尔最近分享了这部备受期待的前传系列剧的一张新宣传照。虽然距离该剧上映还有几个月的时间&#xff0c;但这张照片将激起粉丝们的兴奋之情。 这张照片通过斯托维尔的官方社交账号分享&#xff0c;让观众们看到了年轻时的…

html+css+js前端作业和平精英官网1个页面(带js)

htmlcssjs前端作业和平精英官网1个页面&#xff08;带js&#xff09;有轮播图tab切换等功能 下载地址 https://download.csdn.net/download/qq_42431718/89597007 目录1 目录2 项目视频 htmlcssjs前端作业和平精英官网1个页面&#xff08;带js&#xff09; 页面1

国家超算互联网平台:模型服务体验与本地部署推理实践

目录 前言一、平台显卡选用1、显卡选择2、镜像选择3、实例列表4、登录服务器 二、平台模型服务【Stable Diffusion WebUI】体验1、模型运行2、端口映射配置3、体验测试 三、本地模型【Qwen1.5-7B-Chat】推理体验1、安装依赖2、加载模型3、定义提示消息4、获取model_inputs5、生…

前端-如何通过docker打包Vue服务成镜像并在本地运行(本地可以通过http://localhost:8080/访问前端服务)

1、下载安装docker&#xff0c;最好在vs code里安装docker的插件。 下载链接&#xff1a;https://www.docker.com/products/docker-desktop &#x1f389; Docker 简介和安装 - Docker 快速入门 - 易文档 (easydoc.net) 2、准备配置文件-dockerfile文件和nginx.conf文件 do…

【Redis 初阶】Redis 常见数据类型(Set、Zset、渐进式遍历、数据库管理)

一、Set 集合 集合类型也是保存多个字符串类型的元素的&#xff08;可以使用 json 格式让 string 也能存储结构化数据&#xff09;&#xff0c;但和列表类型不同的是&#xff0c;集合中&#xff1a; 元素之间是无序的。&#xff08;此处的 “无序” 是和 list 的有序相对应的…

Camera Raw:五阶段修图流程

在使用 Camera Raw 修图时&#xff0c;如果按照一定的流程来进行&#xff0c;可以大大提高工作效率。这里提出的五阶段修图流程&#xff0c;简单来说就是&#xff1a; 1、调亮度&#xff0c;定影调 2、还原校正修复 3、局部调整优化 4、调颜色&#xff0c;定色调 5、存储、输出…

【C语言】qsort详解——能给万物排序的神奇函数

&#x1f984;个人主页:小米里的大麦-CSDN博客 &#x1f38f;所属专栏:https://blog.csdn.net/huangcancan666/category_12718530.html ⚙️操作环境:Visual Studio 2022 目录 一、引言 二、qsort函数介绍 1.函数原型 2.参数说明 2.1比较函数 3.使用示例 3.1对一维数组进…