主干网络篇 | YOLOv8更换主干网络之ShuffleNetV2(包括完整代码+添加步骤+网络结构图)

前言:Hello大家好,我是小哥谈。ShuffleNetV2是一种轻量级的神经网络架构,用于图像分类和目标检测任务。它是ShuffleNet的改进版本,旨在提高模型的性能和效率。ShuffleNetV2相比于之前的版本,在保持模型轻量化的同时,提高了模型的准确性和性能。它在计算资源有限的设备上具有较好的应用潜力!~🌈  

     目录

🚀1. 基础概念

🚀2.网络结构

🚀3.添加步骤

🚀4.改进方法

🍀🍀步骤1:block.py文件修改

🍀🍀步骤2:__init__.py文件修改

🍀🍀步骤3:tasks.py文件修改

🍀🍀步骤4:创建自定义yaml文件

🍀🍀步骤5:新建train.py文件

🍀🍀步骤6:模型训练测试

🚀1. 基础概念

ShuffleNetV2是一种轻量级的神经网络架构,用于图像分类和目标检测任务。它是ShuffleNet的改进版本,旨在提高模型的性能和效率。

ShuffleNetV2的主要特点包括:

  1. 分组卷积:通过将输入通道分成多个组,并在组内进行卷积操作,减少了计算量和参数数量。
  2. 逐点卷积:使用1x1的卷积核进行逐点卷积,用于调整通道数和特征图的维度。
  3. 通道重排:通过将输入特征图按通道进行重排,实现信息的混洗和交互,增强了特征的表达能力。
  4. 瓶颈结构:采用瓶颈结构,即先降维再升维,减少了计算量和参数数量。
  5. 网络设计:ShuffleNet V2通过堆叠多个ShuffleNet单元来构建整个网络,可以根据任务的需求进行不同层数和宽度的配置。

ShuffleNetV2相比于之前的版本,在保持模型轻量化的同时,提高了模型的准确性和性能。它在计算资源有限的设备上具有较好的应用潜力。

shuffleNetV2这篇论文比较硬核,提出了不少新的思想,推荐大家可以看看论文原文。主要思想包括:

  • 模型的计算复杂度不能只看FLOPs,还需要参考一些其他的指标
  • 作者提出了4条如何设计高效网络的准则
  • 基于该准则提出了新的block设置

FLOPS网上有两种:FLOPS和 FLOPs

FLOPS:全大写,指每秒浮点运算次数,可以理解为计算的速度,是衡量硬件性能的一个指标 (硬件)
FLOPs:s小写,指浮点运算数,理解为计算量,可以用来衡量算法/模型的复杂度,(模型)在论文中常用GFLOPs(1 GFLOPs = 10^9FLOPs)

 ShuffleNetV2网络结构:

 原理图:

其中,a、b为ShuffleNetV1原理图,c、d为ShuffleNetV2原理图。

论文题目:《ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design》

论文地址:  https://arxiv.org/pdf/1807.11164.pdf

代码实现:  GitHub - megvii-model/ShuffleNet-Series 


🚀2.网络结构

本文的改进是基于YOLOv8,关于其网络结构具体如下图所示:

YOLOv8官方仓库地址:

GitHub - ultralytics/ultralytics: NEW - YOLOv8 🚀 in PyTorch > ONNX > OpenVINO > CoreML > TFLite

针对本文的改进,作者将所使用的含有预训练权重文件的YOLOv8完整源码进行了上传,大家可在我的“资源”中自行下载。  


🚀3.添加步骤

针对本文的改进,具体步骤如下所示:👇

步骤1:block.py文件修改

步骤2:__init__.py文件修改

步骤3:tasks.py文件修改

步骤4:创建自定义yaml文件

步骤5:新建train.py文件

步骤6:模型训练测试


🚀4.改进方法

🍀🍀步骤1:block.py文件修改

在源码中找到block.py文件,具体位置是ultralytics/nn/modules/block.py,然后将ShuffleNetV2模块代码添加到block.py文件末尾位置。

ShuffleNetV2模块代码:

# ShuffleNetv2核心代码
# By CSDN 小哥谈
import torch
import torch.nn as nndef channel_shuffle(x, groups):batchsize, num_channels, height, width = x.data.size()channels_per_group = num_channels // groupsx = x.view(batchsize, groups, channels_per_group, height, width)x = torch.transpose(x, 1, 2).contiguous()x = x.view(batchsize, -1, height, width)return xclass CBRM(nn.Module):  # Conv BN ReLU Maxpool2ddef __init__(self, c1, c2):super(CBRM, self).__init__()self.conv = nn.Sequential(nn.Conv2d(c1, c2, kernel_size=3, stride=2, padding=1, bias=False),nn.BatchNorm2d(c2),nn.ReLU(inplace=True),)self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)def forward(self, x):return self.maxpool(self.conv(x))class Shuffle_Block(nn.Module):def __init__(self, ch_in, ch_out, stride):super(Shuffle_Block, self).__init__()if not (1 <= stride <= 2):raise ValueError('illegal stride value')self.stride = stridebranch_features = ch_out // 2assert (self.stride != 1) or (ch_in == branch_features << 1)if self.stride > 1:self.branch1 = nn.Sequential(self.depthwise_conv(ch_in, ch_in, kernel_size=3, stride=self.stride, padding=1),nn.BatchNorm2d(ch_in),nn.Conv2d(ch_in, branch_features, kernel_size=1, stride=1, padding=0, bias=False),nn.BatchNorm2d(branch_features),nn.ReLU(inplace=True),)self.branch2 = nn.Sequential(nn.Conv2d(ch_in if (self.stride > 1) else branch_features,branch_features, kernel_size=1, stride=1, padding=0, bias=False),nn.BatchNorm2d(branch_features),nn.ReLU(inplace=True),self.depthwise_conv(branch_features, branch_features, kernel_size=3, stride=self.stride, padding=1),nn.BatchNorm2d(branch_features),nn.Conv2d(branch_features, branch_features, kernel_size=1, stride=1, padding=0, bias=False),nn.BatchNorm2d(branch_features),nn.ReLU(inplace=True),)@staticmethoddef depthwise_conv(i, o, kernel_size, stride=1, padding=0, bias=False):return nn.Conv2d(i, o, kernel_size, stride, padding, bias=bias, groups=i)def forward(self, x):if self.stride == 1:x1, x2 = x.chunk(2, dim=1)out = torch.cat((x1, self.branch2(x2)), dim=1)else:out = torch.cat((self.branch1(x), self.branch2(x)), dim=1)out = channel_shuffle(out, 2)return out

再然后,在block.py文件最上方下图所示位置加入CBRMShuffle_Block

🍀🍀步骤2:__init__.py文件修改

在源码中找到__init__.py文件,具体位置是ultralytics/nn/modules/__init__.py

修改1:加入CBRMShuffle_Block,具体如下图所示:

修改2:加入CBRMShuffle_Block,具体如下图所示:

🍀🍀步骤3:tasks.py文件修改

在源码中找到tasks.py文件,具体位置是ultralytics/nn/tasks.py

修改1:在下图所示位置导入类名CBRMShuffle_Block

修改2:找到parse_model函数(736行左右),在下图中所示位置添加如下代码。

 # -------ShuffleNetv2------------elif m in [CBRM, Shuffle_Block]:c1, c2 = ch[f], args[0]if c2 != nc:c2 = make_divisible(min(c2, max_channels) * width, 8)args = [c1, c2, *args[1:]]# --------------------------------

具体添加位置如下图所示:

🍀🍀步骤4:创建自定义yaml文件

在源码ultralytics/cfg/models/v8目录下创建yaml文件,并命名为:yolov8_ShuffleNetV2.yaml。具体如下图所示:

yolov8_ShuffleNetV2.yaml文件完整代码如下所示:

# Parameters
nc: 80  # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'# [depth, width, max_channels]n: [0.33, 0.25, 1024]  # YOLOv8n summary: 225 layers,  3157200 parameters,  3157184 gradients,   8.9 GFLOPss: [0.33, 0.50, 1024]  # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients,  28.8 GFLOPsm: [0.67, 0.75, 768]   # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients,  79.3 GFLOPsl: [1.00, 1.00, 512]   # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPsx: [1.00, 1.25, 512]   # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs# YOLOv8.0n backbone
backbone:# [from, repeats, module, args]- [ -1, 1, CBRM, [ 32 ] ] # 0-P2/4- [ -1, 1, Shuffle_Block, [ 128, 2 ] ]  # 1-P3/8- [ -1, 3, Shuffle_Block, [ 128, 1 ] ]  # 2- [ -1, 1, Shuffle_Block, [ 256, 2 ] ]  # 3-P4/16- [ -1, 7, Shuffle_Block, [ 256, 1 ] ]  # 4- [ -1, 1, Shuffle_Block, [ 512, 2 ] ]  # 5-P5/32- [ -1, 3, Shuffle_Block, [ 512, 1 ] ]  # 6# YOLOv8.0n head
head:- [-1, 1, nn.Upsample, [None, 2, 'nearest']]- [[-1, 3], 1, Concat, [1]]  # cat backbone P4- [-1, 3, C2f, [512]]  # 9- [-1, 1, nn.Upsample, [None, 2, 'nearest']]- [[-1, 2], 1, Concat, [1]]  # cat backbone P3- [-1, 3, C2f, [256]]  # 12 (P3/8-small)- [-1, 1, Conv, [256, 3, 2]]- [[-1, 9], 1, Concat, [1]]  # cat head P4- [-1, 3, C2f, [512]]  # 15 (P4/16-medium)- [-1, 1, Conv, [512, 3, 2]]- [[-1, 6], 1, Concat, [1]]  # cat head P5- [-1, 3, C2f, [1024]]  # 18 (P5/32-large)- [[12, 15, 18], 1, Detect, [nc]]  # Detect(P3, P4, P5)
🍀🍀步骤5:新建train.py文件

在源码根目录下新建train.py文件,文件完整代码如下所示:

from ultralytics import YOLO# Load a model
model = YOLO(r'C:\Users\Lenovo\PycharmProjects\ultralytics-main\ultralytics\cfg\models\v8\yolov8_ShuffleNetV2.yaml')  # build a new model from YAML
model = YOLO('yolov8n.pt')  # load a pretrained model (recommended for training)
model = YOLO(r'C:\Users\Lenovo\PycharmProjects\ultralytics-main\ultralytics\cfg\models\v8\yolov8_ShuffleNetV2.yaml').load('yolov8n.pt')  # build from YAML and transfer weights# Train the model
model.train(data=r'C:\Users\Lenovo\PycharmProjects\ultralytics-main\ultralytics\cfg\datasets\helmet.yaml', epochs=100, imgsz=640)

注意:一定要用绝对路径,以防发生报错。

🍀🍀步骤6:模型训练测试

train.py文件,点击“运行”,在作者自制的安全帽佩戴检测数据集上,模型可以正常训练。

模型训练过程: 

模型训练结果: 

 关于本次改进所使用的安全帽佩戴检测数据集,已上传至我的“资源”中,大家可免费下载。


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://xiahunao.cn/news/2868503.html

如若内容造成侵权/违法违规/事实不符,请联系瞎胡闹网进行投诉反馈,一经查实,立即删除!

相关文章

用户视角的比特币和以太坊外围技术整理

1. 引言 要点&#xff1a; 比特币L2基本强调交易内容的隐蔽性&#xff0c;P2P交易&#xff08;尤其是支付&#xff09;成为主流&#xff0c;给用户带来一定负担&#xff08;闪电网络&#xff09;在以太坊 L2 中&#xff0c;一定程度上减少了交易的隐蔽性&#xff0c;主流是实…

【Linux进程状态】

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 目录 前言 一、直接谈论Linux的进程状态 看看Linux内核源代码怎么说 1.1、R状态 -----> 进程运行的状态 1.2、S状态 -----> 休眠状态(进程在等待“资源”就绪) 1.3、T状…

突破编程_前端_JS编程实例(工具栏组件)

1 开发目标 工具栏组件旨在模拟常见的桌面软件工具栏&#xff0c;所以比较适用于 electron 的开发&#xff0c;该组件包含工具栏按钮、工具栏分割条和工具栏容器三个主要角色&#xff0c;并提供一系列接口和功能&#xff0c;以满足用户在不同场景下的需求&#xff1a; 点击工具…

Vue+SpringBoot打造音乐平台

目录 一、摘要1.1 项目介绍1.2 项目录屏 二、功能模块三、系统展示 四、核心代码4.1 查询单首音乐4.2 新增音乐4.3 新增音乐订单4.4 查询音乐订单4.5 新增音乐收藏 五、免责说明 一、摘要 1.1 项目介绍 基于微信小程序JAVAVueSpringBootMySQL的音乐平台&#xff0c;包含了音乐…

tigramite教程(五)使用TIGRAMITE 进行自助聚合和链接置信度量化

使用TIGRAMITE 进行自助聚合和链接置信度量化 自助聚合&#xff08;Bagging&#xff09;和置信度估计例子数据生成模型基本的PCMCIBagged-PCMCI使用优化后的pc_alpha进行自举聚合使用优化的pc_alpha进行CMIknn的自举聚合 TIGRAMITE是一个用于时间序列分析的Python模块。它基于P…

如何批量获取公众号所有文章的阅读数点赞数和留言数导出excel?

如何批量获取公众号所有文章的阅读数点赞数和留言数导出excel&#xff1f;我写了个脚本批量抓取&#xff0c;导出的excel文章数据包含文章日期&#xff0c;文章标题&#xff0c;文章链接&#xff0c;文章简介&#xff0c;文章作者&#xff0c;文章封面图&#xff0c;是否原创&a…

《圣斗士星矢:纵横宇宙》(上)AI制作真人版大电影

《圣斗士星矢&#xff1a;纵横宇宙》&#xff08;上&#xff09;AI制作真人版大电影 平行宇宙&#xff0c;黑暗来袭&#xff0c;十二件黄金圣衣合体成为究极秘密武器&#xff01; 《圣斗士星矢&#xff1a;纵横宇宙》&#xff08;上&#xff09;电影开场&#xff0c;星矢等一众…

网络通信与网络协议

网络编程是指利用计算机网络实现程序之间通信的一种编程方式。在网络编程中&#xff0c;程序需要通过网络协议(如 TCP/IP)来进行通信&#xff0c;以实现不同计算机之间的数据传输和共享。在网络编程中&#xff0c;通常有三个基本要素 IP 地址:定位网络中某台计算机端口号port:定…

Nginx高级技术: 代理缓存配置

一、缓存说明 Nginx缓存&#xff0c;Nginx 提供了一个强大的反向代理和 HTTP 服务器功能&#xff0c;同时也是一个高效的缓存服务器。一般情况下系统用到的缓存有以下三种&#xff1a; 1、服务端缓存&#xff1a;缓存存在后端服务器&#xff0c;如 redis。 2、代理缓存&#…

6.【Linux】进程间通信(管道命名管道||简易进程池||简易客户端服务端通信)

介绍 进程间通信的方式 1.Linux原生支持的管道----匿名和命名管道 2.System V-----共享内存、消息队列、信号量 3.Posix------多线程、网路通信 进程间通信目的 数据传输&#xff1a;一个进程需要将它的数据发送给另一个进程 资源共享&#xff1a;多个进程之间共享同样的资源。…

UE4_调试工具_绘制调试球体

学习笔记&#xff0c;仅供参考&#xff01; 效果&#xff1a; 步骤&#xff1a; 睁开眼睛就是该变量在此蓝图的实例上可公开编辑。 勾选效果&#xff1a;

VUE3生命周期钩子

概念 每个 Vue 组件实例在创建时都需要经历一系列的初始化步骤&#xff0c;比如设置好数据侦听&#xff0c;编译模板&#xff0c;挂载实例到 DOM&#xff0c;以及在数据改变时更新 DOM。在此过程中&#xff0c;它也会运行被称为生命周期钩子的函数&#xff0c;让开发者有机会在…

文件包含例子

一、常见的文件包含函数 php中常见的文件包含函数有以下四种&#xff1a; include() require() include_once() require()_once() include与require基本是相同的&#xff0c;除了错误处理方面: include()&#xff0c;只生成警告&#xff08;E_WARNING&#xff09;&#x…

基于java+springboot+vue实现的自习室管理和预约系统(文末源码+Lw)23-177

摘 要 传统办法管理信息首先需要花费的时间比较多&#xff0c;其次数据出错率比较高&#xff0c;而且对错误的数据进行更改也比较困难&#xff0c;最后&#xff0c;检索数据费事费力。因此&#xff0c;在计算机上安装自习室管理和预约系统软件来发挥其高效地信息处理的作用&a…

Python爬取淘宝商品评价信息实战

文章目录 一、分析需要爬取的页面二、实现爬取商品评价信息的代码1、通过解析显示评价信息的元素获取商品评价信息2、通过mitmproxy代理进行流量抓包获取商品评价信息 三、附-完整代码 前期出了一个《爬取京东商品评价信息实战》的教程&#xff0c;最近又有网友提到要出一个爬淘…

案例分析篇13:系统分析与设计考点(2024年软考高级系统架构设计师冲刺知识点总结系列文章)

专栏系列文章推荐: 2024高级系统架构设计师备考资料(高频考点&真题&经验)https://blog.csdn.net/seeker1994/category_12593400.html 【历年案例分析真题考点汇总】与【专栏文章案例分析高频考点目录】(2024年软考高级系统架构设计师冲刺知识点总结-案例分析篇-…

幸福金龄会第二届《锦绣中华》广东省中老年协会携团队共襄盛举

近日&#xff0c;一场盛大的文化艺术盛宴——第二届《锦绣中华》中老年文旅文化艺术节在广东隆重举行。此次活动由幸福金龄会主办&#xff0c;吸引了广东省内各中老年协会的领导及文艺团队纷纷参与&#xff0c;共同为中华文化的传承与发展贡献力量。 广东省各中老年协会的会长、…

Java项目:56 ssm681基于Java的超市管理系统+jsp

作者主页&#xff1a;源码空间codegym 简介&#xff1a;Java领域优质创作者、Java项目、学习资料、技术互助 文中获取源码 项目介绍 功能包括:商品分类&#xff0c;供货商管理&#xff0c;库存管理&#xff0c;销售统计&#xff0c;用户及角色管理&#xff0c;等等功能。项目采…

WpsOfficeExcel表格固定首行,点视图下的冻结窗格下的冻结首行

wps 和 微软 excel 表格固定首行的方法基本一样 WpsExcel表格固定首行,点视图下的冻结窗格下的冻结首行 WpsExcel表格固定首行,点视图下的冻结窗格下的冻结首行 Excel表格固定首行,点视图下的冻结窗格下的冻结首行 在Excel中固定首行&#xff0c;通常是通过“冻结窗格”功能…

复现文件上传漏洞

一、搭建upload-labs环境 将下载好的upload-labs的压缩包&#xff0c;将此压缩包解压到WWW中&#xff0c;并将名称修改为upload&#xff0c;同时也要在upload文件中建立一个upload的文件。 然后在浏览器网址栏输入&#xff1a;127.0.0.1/upload进入靶场。 第一关 选择上传文件…