基于迁移学习的手势分类模型训练

1、基本原理介绍

       这里介绍的单指模型迁移。一般我们训练模型时,往往会自定义一个模型类,这个类中定义了神经网络的结构,训练时将数据集输入,从0开始训练;而迁移学习中(单指模型迁移策略),是在一个已经有过训练基础的模型上,用自己的数据集,进一步训练,使得这个模型能够完成我们需要的任务。

这么做有有这样几个显而易见的好处:

※  因为模型之前被训练过,所以初始参数不会是0,这样能够加速模型训练

※  因为预训练模型(什么是预训练模型下文会讲到)在其他数据集上训练过,而其他数据集往往和我们用的数据集存在一定的区别,所以这可以提高模型的泛化能力

※  通过迁移学习,可以将来自大规模数据的优势转移到小规模或新任务上,提高模型的表现和效果

2、预训练模型

        在进行迁移学习时,我们要先找到一个预训练模型。在分类任务领域,比较流行的如resnet系列、mobilenet系列(更轻量化)、vgg(系列)、efficientnet(系列)等等网络,都是比较常用且容易获得的预训练模型,这些模型都能够通过python直接下载。

        而且由于上述模型基本都是在ImageNet这一大规模,多分类类别的数据集上进行过训练的,所以对于简单的二分类等少数类别分类,能有较好的效果。

3、训练流程

迁移学习完整的训练流程和一般搭建神经网络的训练模型的流程基本类似:数据预处理->数据集的切分->加载预训练模型(搭建神经网络)->设置超参数/损失函数/优化器等->训练模型

3.1 模型训练

下面的代码是一个利用mobilenet网络训练得到的手势分类模型,该模型能够较准确的分类不同类别手势。

相关解释已在代码中注释说明。

from torchvision.models import mobilenet_v2
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize, RandomHorizontalFlip, RandomRotation# 定义数据预处理和增强器
transform = Compose([RandomHorizontalFlip(),  # 随机水平翻转RandomRotation(10),      # 随机旋转10度Resize((224, 224)),CenterCrop(224),ToTensor(),Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])# 加载数据集并应用预处理和增强器
dataset = ImageFolder(root='data', transform=transform)
# 这里由于数据比较少,将所有数据集全部用来训练,得到的模型直接拿来用了,这其实不算是非常规范的操作,仅供参考# 定义网络结构
model = mobilenet_v2(pretrained=True)  # 加载预训练模型,也可以试试其他模型,效果差别挺大的
num_ftrs = model.classifier[1].in_features
model.classifier[1] = nn.Linear(num_ftrs, 5)  # 假设是5分类问题,具体几分类,改这里的参数就行了# 将模型移动到设备上
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)# 定义优化器和损失函数
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5)
criterion = nn.CrossEntropyLoss()# 定义训练循环
def train_model(model, criterion, optimizer, num_epochs, train_loader):for epoch in range(num_epochs):model.train()  # 设置模型为训练模式train_loss = 0.0correct = 0total = 0for inputs, labels in train_loader:inputs, labels = inputs.to(device), labels.to(device)optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()train_loss += loss.item() * inputs.size(0)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()epoch_loss = train_loss / totalepoch_acc = 100. * correct / totalprint(f'Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.2f}%')# 创建训练集的DataLoader
train_loader = DataLoader(dataset, batch_size=32, shuffle=True)# 开始训练模型
train_model(model, criterion, optimizer, num_epochs=15, train_loader=train_loader)
torch.save(model, 'my_model(1).pth')

3.2 数据集文件结构

当然,你也可以自己定义读取数据集的data_loader类。

3.3 模型推理

这段代码是用训练得到的模型对一张图片进行推理测试的,如果需要对系列图片进行推理,评估模型效果,可自行修改,调用对应函数即可。

import torch
from PIL import Image
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
def predict_image(image_path, model_path='my_model(1).pth'):image = Image.open(image_path).convert("RGB")# 对测试的图片进行预处理,需要和训练时处理的方式一样transform = Compose([Resize((224, 224)),CenterCrop(224),ToTensor(),Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])image_tensor = transform(image).unsqueeze(0)device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')image_tensor = image_tensor.to(device)model = torch.load(model_path,map_location=device)model.eval()with torch.no_grad():output = model(image_tensor)_, predicted = torch.max(output.data, 1)  # 获得分类标记return predicted.item()
if __name__=="__main__":image_path = "test2/6.jpg"print(predict_image(image_path))

3.4 整体项目文件

4、补充说明

        这种利用迁移学习策略,进行少类别,不同类别特征差距小的任务需求来说,效果一般来说是比较好的。因为之前做过相关实验,准确率90%以上是很容易的,所以这里没有模型评估,生成混淆矩阵等过程。对于多类别分类,建议有完整的评估体系。

        上述使用的方法仅适用于分类任务,对于真正的目标检测如手势识别,直接使用该模型的问题是:由于无法定位手势的位置,所以导致识别不准确。

        本实验数据集是不同类别手势图片,为自制,不开源。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://xiahunao.cn/news/3267593.html

如若内容造成侵权/违法违规/事实不符,请联系瞎胡闹网进行投诉反馈,一经查实,立即删除!

相关文章

Layui修改表格分页为英文

Layui修改表格分页为英文 1.前言2.Laypage属性 1.前言 主要记录初次使用Layui没有好好看官方文档踩坑,修改了源码才发现可以自定义 使用的Layui版本2.9.14 2.Laypage属性 Laypage属性中带的有自定义文本的属性 示例代码 table.render({.......page: {skipText: …

状态机 XState 使用

状态机 一般指的是有限状态机(Finite State Machine,FSM),又可以称为有限状态自动机(Finite State Automation,FSA),简称状态机,它是一个数学模型,表示有限个…

硬核科普:什么是网络准入控制系统|网络准入控制系统四大品牌介绍

网络准入控制系统(Network Access Control, NAC)是一种用于确保只有授权设备和用户才能接入网络的安全技术。 本文将介绍几种常用的网络准入控制系统,帮助您更好地了解如何选择适合您企业的NAC系统。 网络准入控制的重要性和作用 网络准入控…

java学习--练习题

在类中this.属赋值,则外部创建对象调用其值也会随之一样 package com.test01;/* author:我与java相爱相杀---c语言梦开始的地方 今天又是努力学习的一天!!!! */ /*1. 在Frock类中声明私有的静态属性currentNum[int类型…

idm软件最新破解版下载 idm永久激活码 IDM中文绿色特别版 idm下载器汉化版

在互联网时代,下载管理软件成为了我们日常使用电脑不可或缺的工具之一。说起下载工具,大家的第一反应可能是网盘、迅雷。但在PC端其实还有一个可以对标他们的软件——IDM,这是一个口碑炸裂的多线程下载工具。 Internet Download Manager&…

让你的设计更出色:10个最受欢迎的3D画图工具盘点

随着渲染工具的发生和客户对立体效果的要求越来越高,设计师应该能够及时用设计风格解释空间界面,全面使用3D画图工具进行展览设计。3D画图工具在建筑、工程、产品设计等行业使用不同的算法,为图像添加色调、质感等细节。不同类型的3D画图工具…

鸿蒙HarmonyOS【应用开发五、组件介绍】

✍️作者简介:小北编程(专注于HarmonyOS、Android、Java、Web、TCP/IP等技术方向) 🐳博客主页: 开源中国、稀土掘金、51cto博客、博客园、知乎、简书、慕课网、CSDN 🔔如果文章对您有一定的帮助请&#x1f…

Java之 jvm

jvm之管理内存 程序计数器:当前线程所执行的字节码的行号指示器。程序计数器是唯一一个不会出现 OutOfMemoryError 的内存区域,它的生命周期随着线程的创建而创建,随着线程的结束而死亡。Java虚拟机栈 方法调用 一个方法调用都会有对应的栈帧…

Redis - SpringDataRedis - RedisTemplate

目录 概述 创建项目 引入依赖 配置文件 测试代码 测试结果 数据序列化器 自定义RedisTemplate的序列化方式 测试报错 添加依赖后测试 存入一个 String 类型的数据 测试存入一个对象 优化 -- 手动序列化 测试存入一个Hash 总结: 概述 SpringData 是 S…

《Milvus Cloud向量数据库指南》——BGE-M3:多功能、多语言、多粒度的文本表示学习模型

引言 在自然语言处理(NLP)领域,随着大数据时代的到来,对文本信息的精准处理与高效检索成为了研究热点。BERT(Bidirectional Encoder Representations from Transformers)作为近年来NLP领域的里程碑式模型,以其强大的上下文理解能力在多项任务中取得了显著成效。然而,面…

刘纪鹏:“3万亿资金将股市拉升至4000点”,你能赚?

本周刘纪鹏提出了一个观点:花费3万亿资金将股市拉升至4000点,有望带来25万亿的财富增长。 3万亿的投入与25万亿的潜在增长确实令人心动。股市并非简单的投入资金就能涨,还需要考虑市场情绪、经济基本面等因素的影响。举个例子,某个…

【leetcode 详解】找出区分值(C++思路详解):这【中等】题怎么十分钟就写完了?

评价:就笔者的感觉吧,leetcode上难度标为“中等”的题目往往不是说需要什么高深的算法来解决,但基本都涉及到 “问题转化” 的能力要求,换言之,难点往往在于思维。 tip:要解决这类问题,笔者推荐…

python3.10.4——Windows环境安装

python下载官网:https://www.python.org/downloads/ 如果安装在C盘,需要右键→选择“以管理员身份运行” 勾选2个按钮,选择自定义安装 全部选择,点击Next 更改安装路径 命令行检查python是否安装成功: 出现版本号说明…

如何使用C#自制一个Windows安装包

原文链接:https://www.cnblogs.com/zhaotianff/p/17387496.html 以前都在用InstallShield制作安装包,基本需求是能满足的,但也有一些缺点: 1、界面不能完全定制 2、不能直接调用代码里的功能 平常使用一些其它软件,…

数据结构(Java):Map集合Set集合哈希表

目录 1、介绍 1.1 Map和Set 1.2 模型 2、Map集合 2.1 Map集合说明 2.2 Map.Entry<K&#xff0c;V> 2.3 Map常用方法 2.4 Map注意事项及实现类 3、Set集合 3.1 Set集合说明 3.2 Set常用方法 3.3 Set注意事项及其实现类 4、TreeMap&TreeSet 4.1 集合类TreeM…

嵌入式中什么是三次握手

在开始前刚好我有一些资料&#xff0c;是我根据网友给的问题精心整理了一份「嵌入式的资料从专业入门到高级教程」&#xff0c;点个关注在评论区回复“666”之后私信回复“666”&#xff0c;全部无偿共享给大家&#xff01;&#xff01;&#xff01; 在网络数据传输中&#xf…

pytorch3d的安装

在这个网址中&#xff0c;下载对应的pytorch3d安装包 https://anaconda.org/pytorch3d/pytorch3d/files下载完成后使用下面命令进行安装 conda install ./pytorch3d-0.7.7-py39_cu118_pyt201.tar.bz2

可见性::

目录 定义&#xff1a; 解决方法&#xff1a; ①使用synchronized实现缓存和内存的同步 修改一&#xff1a; 加入语句&#xff1a; 代码&#xff1a; 修改2&#xff1a; 在代码块中加入&#xff1a; 代码&#xff1a; 执行结果&#xff1a; 原因&#xff1a; ②使用…

Linux--Socket 编程 UDP(简单的回显服务器和客户端代码)

目录 0.上篇文章 1.V1 版本 - echo server 1.1认识接口 1.2实现 V1 版本 - echo server&#xff08;细节&#xff09; 1.3添加的日志系统&#xff08;代码&#xff09; 1.4 解析网络地址 1.5 禁止拷贝逻辑&#xff08;基类&#xff09; 1.6 服务端逻辑 &#xff08;代码&…

【C/C++】printf和cout的区别

创作不易&#xff0c;本篇文章如果帮助到了你&#xff0c;还请点赞 关注支持一下♡>&#x16966;<)!! 主页专栏有更多知识&#xff0c;如有疑问欢迎大家指正讨论&#xff0c;共同进步&#xff01; &#x1f525;c系列专栏&#xff1a;C/C零基础到精通 &#x1f525; 给大…