手把手带你YOLOv5/v7 添加注意力机制,30多种模块分析①,SE模块,SK模块

目录

    • 一、注意力机制介绍
      • 1、什么是注意力机制?
      • 2、注意力机制的分类
      • 3、注意力机制的核心
    • 二、SE模块
      • 1、SE模块的原理
      • 2、代码实例
      • 3、实验结果
      • 4、应用示例
        • (1)在 `models/yolo.py` 文件中定义 `SEModule` 类,用于实现SE模块。
        • (2)在 `models/common.py` 文件中定义 `C3` 和 `CBL` 类时,引入 `SEModule` 类,并在需要的位置添加 SE 模块即可。
    • 三、SK模块
      • 1、SK模块的原理
      • 2、实验结果
      • 3、应用示例
        • (1)在models/common.py中定义SKConv模块:
        • (2)在models/yolo.py中使用SKConv模块。例如,在CSPBlock中使用SKConv的代码如下:
        • (3)在训练脚本train.py中设置使用SK模块。例如,可以将`--sk`参数设置为True启用SK模块:

大家好,我是哪吒。

🏆本文收录于,目标检测YOLO改进指南。

本专栏均为全网独家首发,内附代码,可直接使用,改进的方法均是2023年最近的模型、方法和注意力机制。每一篇都做了实验,并附有实验结果分析,模型对比。


在机器学习和自然语言处理领域,随着数据的不断增长和任务的复杂性提高,传统的模型在处理长序列或大型输入时面临一些困难。传统模型无法有效地区分每个输入的重要性,导致模型难以捕捉到与当前任务相关的关键信息。为了解决这个问题,注意力机制(Attention Mechanism)应运而生。

一、注意力机制介绍

1、什么是注意力机制?

注意力机制(Attention Mechanism)是一种在机器学习和自然语言处理领域中广泛应用的重要概念。它的出现解决了模型在处理长序列或大型输入时的困难,使得模型能够更加关注与当前任务相关的信息,从而提高模型的性能和效果。

本文将详细介绍注意力机制的原理、应用示例以及应用示例。

2、注意力机制的分类

类别描述
全局注意力机制(Global Attention)在计算注意力权重时,考虑输入序列中的所有位置或元素,适用于需要全局信息的任务。
局部注意力机制(Local Attention)在计算注意力权重时,只考虑输入序列中的局部区域或邻近元素,适用于需要关注局部信息的任务。
自注意力机制(Self Attention)在计算注意力权重时,根据输入序列内部的关系来决定每个位置的注意力权重,适用于序列中元素之间存在依赖关系的任务。
Bahdanau 注意力机制全局注意力机制的一种变体,通过引入可学习的对齐模型,对输入序列的每个位置计算注意力权重。
Luong 注意力机制全局注意力机制的另一种变体,通过引入不同的计算方式,对输入序列的每个位置计算注意力权重。
Transformer 注意力机制自注意力机制在Transformer模型中的具体实现,用于对输入序列中的元素进行关联建模和特征提取。

3、注意力机制的核心

注意力机制的核心思想是根据输入的上下文信息来动态地计算每个输入的权重。这个过程可以分为三个关键步骤:计算注意力权重、对输入进行加权和输出。首先,计算注意力权重是通过将输入与模型的当前状态进行比较,从而得到每个输入的注意力分数。这些注意力分数反映了每个输入对当前任务的重要性。对输入进行加权是将每个输入乘以其对应的注意力分数,从而根据其重要性对输入进行加权。最后,将加权后的输入进行求和或者拼接,得到最终的输出。注意力机制的关键之处在于它允许模型在不同的时间步或位置上关注不同的输入,从而捕捉到与任务相关的信息。

二、SE模块

1、SE模块的原理

SE(Squeeze-and-Excitation)模块是一种轻量级的注意力机制。它通过学习通道间的相互依赖关系,来增强有用的特征并抑制无用的特征。SE模块由两个部分组成:squeeze部分和excitation部分。squeeze部分使用全局平均池化层将每个通道的特征图压缩为一个标量,然后传递给excitation部分。excitation部分使用两个密集连接层,分别使用ReLU激活函数和sigmoid激活函数,来计算通道权重。最后,将每个通道的特征图与其对应的通道权重相乘,以产生具有更高加权响应的特征图。

在这里插入图片描述

SE模块的实现非常简单,只需要在卷积层后添加一个SE模块即可。

2、代码实例

下面是一个使用PyTorch实现的SE模块代码示例:

import torch.nn as nnclass SEModule(nn.Module):def __init__(self, in_channels, reduction_ratio=16):super(SEModule, self).__init__()self.avg_pool = nn.AdaptiveAvgPool2d(output_size=(1,1))self.fc1 = nn.Linear(in_channels, in_channels // reduction_ratio)self.fc2 = nn.Linear(in_channels // reduction_ratio, in_channels)self.relu = nn.ReLU(inplace=True)self.sigmoid = nn.Sigmoid()def forward(self, x):batch_size, channels, _, _ = x.size()y = self.avg_pool(x).view(batch_size, channels)y = self.fc1(y)y = self.relu(y)y = self.fc2(y)y = self.sigmoid(y).view(batch_size, channels, 1, 1)return x * y.expand_as(x)

3、实验结果

同时在ImageNet数据集上训练基线架构及其SENet对应架构。SENet表现出改进的优化特性,并且产生了持续的性能增益,在整个训练过程中都能保持一致。

在这里插入图片描述

在ImageNet验证集上,最先进的CNN单模型使用224×224和320×320/299×299尺寸进行剪裁后的错误率(%)。

4、应用示例

在YOLOv5中,SE模块可以很方便地添加到主干网络的每个残差块中。

具体操作步骤如下:

(1)在 models/yolo.py 文件中定义 SEModule 类,用于实现SE模块。

import torch.nn as nnclass SEModule(nn.Module):def __init__(self, in_channels, reduction_ratio=16):super(SEModule, self).__init__()self.avg_pool = nn.AdaptiveAvgPool2d(output_size=(1, 1))self.fc1 = nn.Linear(in_channels, in_channels // reduction_ratio)self.relu = nn.ReLU(inplace=True)self.fc2 = nn.Linear(in_channels // reduction_ratio, in_channels)self.sigmoid = nn.Sigmoid()def forward(self, x):batch_size, channels, height, width = x.size()y = self.avg_pool(x).view(batch_size, channels)y = self.fc1(y)y = self.relu(y)y = self.fc2(y)y = self.sigmoid(y).unsqueeze(2).unsqueeze(3)return x * y.expand_as(x)

(2)在 models/common.py 文件中定义 C3CBL 类时,引入 SEModule 类,并在需要的位置添加 SE 模块即可。

例如,在 C3 类的每个残差块的第一个卷积层后添加 SE 模块:

class C3(nn.Module):# ...def __init__(self, in_channels, out_channels, shortcut=True, e=0.5):# ...hidden_channels = int(out_channels * e)self.conv1 = CBL(in_channels, hidden_channels, 1, 1)self.conv2 = CBL(hidden_channels, out_channels, 3, 1)self.se = SEModule(out_channels)if shortcut and in_channels == out_channels:self.shortcut = nn.Identity()else:self.shortcut = CBL(in_channels, out_channels, 1, 1)# ...

通过以上操作,SE模块就被成功地添加到了YOLOv5中,可以在训练和测试时使用。

三、SK模块

1、SK模块的原理

在这里插入图片描述
SK(Selective Kernel)模块是一种基于注意力机制的模块,可以自适应地选择不同大小的卷积核来处理输入特征图。其主要思想是在局部感受野范围内引入一个自适应的通道注意力机制,以捕获特征之间的相关性,并且增强重要特征的表示。

具体来说,SK模块主要包含以下三个步骤:

  1. 全局池化:首先,将输入特征图通过全局平均池化层压缩为单个通道,得到一个全局的统计量,这个统计量可以代表整个特征图的信息。
  2. 分离变换:然后,将全局池化的输出通过两个全连接层,将通道数分别压缩为较小的数量,再将其恢复到原来的通道数,此时得到了一组代表特征显著程度的向量。
  3. 特征融合:最后,使用一组权重向量对不同大小的卷积核进行加权,将它们融合在一起处理输入特征图,从而增加网络的判别能力。

在第三步中,SK模块使用一组权重向量对不同大小的卷积核进行加权,具体地,假设输入特征图的大小为C×H×W,其中C、H和W分别代表通道数、高度和宽度。对于每个卷积核,SK模块都会生成一个权重向量,用于选择特征图中哪些通道最相关。在这里,SK模块通过一个自适应的通道注意力机制来生成权重向量,使得网络可以自动学习到不同通道之间的相关性以及它们对目标分类的贡献。

class SKConv(nn.Module):def __init__(self, c, r=16, stride=1, L=32):super(SKConv, self).__init__()# 计算分支数目Md = max(int(c/r), L)self.M = 2# 定义每个分支的卷积层和BN层self.conv = nn.ModuleList()for i in range(self.M):self.conv.append(nn.Sequential(nn.Conv2d(c, c, kernel_size=3+i*2, stride=stride, padding=1+i, groups=32, bias=False),nn.BatchNorm2d(c),nn.ReLU(inplace=True)))# 定义全连接层,用于生成权重向量self.fcs = nn.Sequential(nn.Linear(c, d),nn.ReLU(inplace=True),nn.Linear(d, c))# 定义softmax函数,用于归一化权重向量self.softmax = nn.Softmax(dim=1)def forward(self, x):batch_size = x.shape[0]feats = [conv(x).view(batch_size, -1, 1, 1) for conv in self.conv]  # 将每个分支的输出展开,并拼接在一起feats = torch.cat(feats, dim=2)feats = feats.mean(dim=2)  # 对各个通道的特征图取平均值feats = feats.view(batch_size, -1)  # 将所有通道的特征图组合成一个向量feats = self.fcs(feats).view(batch_size, -1, 1, 1)  # 用全连接层生成权重向量feats = self.softmax(feats)  # 对权重向量进行归一化feats = [feats[:, i].contiguous().view(batch_size, 1, x.size(2), x.size(3)) * conv(x) for i, conv in enumerate(self.conv)]  # 对输入特征图进行加权,然后再将它们融合在一起return sum(feats)  # 返回最终的输出结果

具体地,假设输入特征图的大小为C×H×W,其中C、H和W分别代表通道数、高度和宽度。SK模块的主要实现代码如下:

2、实验结果

在这里插入图片描述

在CIFAR数据集上的Top-1错误率(%,10次运行的平均值)。SENet-29和SKNet-29都基于ResNeXt-29,16×32d。

在这里插入图片描述

(a)和(b):两个随机抽样的图像的注意力结果,包含三种不同大小的目标(1.0x、1.5x和2.0x)。左上角:示例图像。左下角:SK 3 4中跨通道的5×5卷积核的注意力值。绘制的结果是16个连续通道的平均值,以方便查看。右侧:不同SK单元中5×5卷积核的注意力值减去3×3卷积核的注意力值。 ©:在ImageNet验证集的所有图像实例中平均的结果,也绘制了标准偏差。

3、应用示例

在YOLOv5中,SK模块主要应用于Backbone网络中的CSPDarknet53模块,以提高特征图的判别能力。

添加SK模块的步骤如下:

(1)在models/common.py中定义SKConv模块:

class SKConv(nn.Module):def __init__(self, c, r=16, stride=1, L=32):super(SKConv, self).__init__()# ...def forward(self, x):# ...return sum(feats)

(2)在models/yolo.py中使用SKConv模块。例如,在CSPBlock中使用SKConv的代码如下:

class CSPBlock(nn.Module):def __init__(self, in_channels, out_channels, num_blocks, shortcut=True, cardinality=1, bottleneck_width=64, sk_ratio=0.1):super(CSPBlock, self).__init__()# ...self.sk = SKConv(int(out_channels * sk_ratio), r=2, L=int(out_channels * sk_ratio)/2)  # 添加SKConv模块def forward(self, x):# ...x1 = self.cv1(x)x2 = self.cv2(x1)x2 = self.split_conv(x2)x3 = self.cv3(x1)x = torch.cat([x2, x3], dim=1)x = self.cv4(x)x = self.bn4(x)x = self.act(x)if self.sk_ratio > 0:  # 使用SKConv模块x = x + self.sk(x)if self.shortcut:x = x + inputsreturn self.act(x)

(3)在训练脚本train.py中设置使用SK模块。例如,可以将--sk参数设置为True启用SK模块:

python train.py --img 640 --batch 16 --epochs 50 --data coco.yaml --weights yolov5s.pt --cache --sk

参考论文:

  1. https://arxiv.org/pdf/1709.01507.pdf
  2. https://arxiv.org/pdf/1903.06586.pdf

在这里插入图片描述

🏆本文收录于,目标检测YOLO改进指南。

本专栏均为全网独家首发,🚀内附代码,可直接使用,改进的方法均是2023年最近的模型、方法和注意力机制。每一篇都做了实验,并附有实验结果分析,模型对比。

🏆哪吒多年工作总结:Java学习路线总结,搬砖工逆袭Java架构师。

🏆往期回顾:

1、YOLOv7如何提高目标检测的速度和精度,基于模型结构提高目标检测速度

2、YOLOv7如何提高目标检测的速度和精度,基于优化算法提高目标检测速度

3、YOLOv7如何提高目标检测的速度和精度,基于模型结构、数据增强提高目标检测速度

4、YOLOv5结合BiFPN,如何替换YOLOv5的Neck实现更强的检测能力?

5、YOLOv5结合BiFPN:BiFPN网络结构调整,BiFPN训练模型训练技巧

6、YOLOv7升级换代:EfficientNet骨干网络助力更精准目标检测

7、YOLOv5改进:引入DenseNet思想打造密集连接模块,彻底提升目标检测性能

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://xiahunao.cn/news/353533.html

如若内容造成侵权/违法违规/事实不符,请联系瞎胡闹网进行投诉反馈,一经查实,立即删除!

相关文章

安装使用cuteFTP注意事项

花絮: 一直以来都使用红帽的共享文件来让windows和linux之间进行传输文件,今天头脑一发热,想使用windows下的cuteFTP软件来代替前面的方法。可谁想到,一是cuteFTP在网上根本找不到序列号,找了N久没找到,后…

CuteFTP安装

CuteFTP 9破解版,百度网盘链接:https://pan.baidu.com/s/16SDjxyQF2WtiPKpZHjcueQ 密码:xblr CuteFTP 9破解版是一款非常实用的商业FTP软件,也就是一个FTP客户端。可能很多人不知道FTP是什么,就是一个文件传输系统&…

Java 面试题:Spring,Spring MVC,Spring Boot 之间什么关系?

来,先和我看张图: Spring全家桶了为了解决不同场景的问题,逐渐演化出多套生态环框,如:Spring、SpringMVC、SpringBoot、SpringCloud。 Spring MVC和Spring Boot都属于Spring,Spring MVC是基于Spring的一个…

Springboot配置文件中的明文密码漏洞

目录 一、背景 二、本地修复测试 1、maven中引入jasypt 2、编写加密解密工具类 3、修改配置文件,增加秘钥 4、秘钥放在启动项 三、生产实现 1、升级打包代码 2、生产yml修改明文密码处 3、修改启动命令 一、背景 最近接收到网安的系统安全…

高德地图添加遮罩,实现圈出某个特定的地区

实现效果 一、准备 1、注册账号并申请Key 2、准备页面 <script type"text/javascript" src"https://webapi.amap.com/maps?v1.4.15&key您申请的key值"></script> <div id"container"></div>#container {widt…

vue 高德地图贴地点

效果图 引入高德api maps和local local 用的2.0.0的 其他版本可能会有不兼容问题 mounted(){//设置 地图this.map new AMap.Map(container, {mapStyle: , // 设置地图的自定义样式-深色zoom: 7.3, //级别center: [119.1, 36.32], //中心点坐标viewMode: 3D, // 地图模式resiz…

iOS 跳转到地图后导航(高德地图,百度地图,腾讯地图,苹果手机原生的地图)

1.现在info.plist里面如下图所示添加 2.在下图输入框中输入地名&#xff0c;然后点击前往目的地会出现如下图所示。&#xff08;如果你不知道地方名如何转化为经纬度请看我的另一篇博客&#xff1a;http://blog.csdn.net/chenyongkai1/article/details/51891135&#xff09; 3.…

高德导航免费,那他靠什么收入?

来源 &#xff5c;一口Linux 一位工作了12年的软件工程师说&#xff1a;当你打开导航时&#xff0c;不需要任何费用&#xff0c;还会给高德公司带来丰厚的收入。当时我不信&#xff0c;去查了相关资料后&#xff0c;才知道这个行业不简单。 出门外出&#xff0c;对路线不熟时&a…

C++编译一些常见的错误集锦

1、段错误&#xff08;Segmentation Fault&#xff09; &#xff08;1&#xff09;段错误&#xff08;Segmentation Fault&#xff09;是一种常见的计算机程序错误&#xff0c;通常指向程序试图访问的内存地址超出了程序可访问的内存范围&#xff0c;或者指针指向了无效的内存…

分布式光伏发电远程监控系统

分布式光伏发电远程监控系统 项目背景 新能源、可再生能源接入电网是智能电网建设的重要组成&#xff0c;也是能源互联网发展的基础。近年来&#xff0c;太阳能光伏发电技术快速发展&#xff0c;光伏发电并网对配电网的影响也不断加深。电网调度人员需要人工参与光伏发电站的发…

短视频矩阵源码技术开发

短视频矩阵是一种常见的视频编码标准&#xff0c;它通过将视频分成多个小块并对每个小块进行压缩来实现高效的视频传输。在本文中&#xff0c;我们将介绍短视频矩阵的原理和实现&#xff0c;并提供示例代码。 $where_time array(); // 时间 $where_time[] array(name>fbr…

运用正则表达式匹配QQ邮箱

~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 1.首先创建一个新文件夹 1.命名好名字后用Visual Studio Code打开 创建一个HTML文档用&#xff01;字符按下tab键后完成基本格式 创建一个input标签占位符写下请输入QQ邮箱再给他一个id方便调用 …

html5form表单提交到QQ邮箱,javaMai+Springl实现给QQ邮箱发邮件(带附件,html格式)...

以前的时候想着java发邮件很简单,因为当时使用的是outlook实现的,有兴趣的可以去看看之前的两篇博客文章,1.使用java底层实现邮件的发送(含测试,源码) 和 2.使用Spring实现邮件的发送(含测试,注释,源码) 就在今天,遇到的需求是给一个QQ邮箱发一份邮件,刚看到需求一看…

Linux向qq邮箱发送html表格以及遇到的问题

由于是实验&#xff0c;做的比较简陋&#xff0c;邮箱直接显示html界面&#xff0c;有诸多要求&#xff0c;本人对html不太擅长&#xff0c;详情可以参考http://www.ruanyifeng.com/blog/2013/06/html_email.html linux向qq发送邮件参考另一篇博客&#xff1a; https://blog.cs…

PGP加密解密QQ邮箱邮件

今天学习了PGP加密解密QQ邮箱邮件的方法&#xff0c;分享一下&#x1f601;&#x1f601;&#x1f601;&#x1f601; 涉及软件&#xff1a;PGP(PGP Desktop)、Outlook(office的) 加密&#xff1a; 第一步&#xff1a;安装PGP软件&#x1f602;&#x1f602;&#x1f602; 具体…

qq邮箱发html版式是乱的,为什么在Word里编辑的内容到QQ邮箱里发给别人是乱的,我用附件发的呀...

为什么在Word里编辑的内容到QQ邮箱里发给别人是乱的,我用附件发的呀以下文字资料是由(历史新知网www.lishixinzhi.com)小编为大家搜集整理后发布的内容,让我们赶快一起来看一下吧! 为什么在Word里编辑的内容到QQ邮箱里发给别人是乱的,我用附件发的呀, 为什么在Word里编辑好的…

ipad查看qq邮箱收件服务器,ipad怎么设置qq邮箱以便通过iPad来接收QQ邮箱收到的邮件...

大家可以通过下文来了解&#xff0c;小编将会演示ipad怎么设置qq邮箱&#xff0c;设置成功之后我们就能通过iPad来接收QQ邮箱收到的邮件&#xff0c;快来操作吧~ 下文是以iPhone设置QQ邮箱为例&#xff0c;和iPad步骤是一样的哦&#xff0c;首先进入“Mail”&#xff0c;点击“…

SpringCloud Gateway网关多路由配置访问404解决方案

文章目录 一、问题描述&#xff1a;SpringCloud GateWay Eureka访问出现404&#xff0c;Not Found二、解决方案:1、 配置 filters: - StripPrefix12、删除冲突依赖3、检查启动类4、检查配置文件 一、问题描述&#xff1a;SpringCloud GateWay Eureka访问出现404&#xff0c…

这才叫软件测试工程师,你那最多是混口饭吃罢了....

前些天和大学室友小聚了一下&#xff0c;喝酒喝大发了&#xff0c;谈天谈地谈人生理想&#xff0c;也谈到了我们各自的发展&#xff0c;感触颇多。曾经找工作我迷茫过、徘徊不&#xff0c;毕业那会我屡屡面试失败&#xff0c;处处碰壁&#xff1b;工作两年后我一度想要升职加薪…

QQ截图快捷键设置

1、打开qq找到主菜单找到设置 2、找到热键----------点击设置热键 3、双击捕捉屏幕进行设置快捷键 然后就OK了。