昇思25天学习打卡营第10天|ShuffleNet图像分类

ShuffleNet网络结构

ShuffleNet是一种专为移动设备设计的、计算效率极高的卷积神经网络(CNN)架构。其网络结构的设计主要围绕减少计算复杂度和提高模型效率展开,通过引入逐点分组卷积(Pointwise Group Convolution)和通道洗牌(Channel Shuffle)两种新技术,实现了在保持精度的同时大幅降低计算成本。

逐点分组卷积(Pointwise Group Convolution):

逐点分组卷积是ShuffleNet中用于减少1x1卷积计算复杂度的方法。它将输入特征图的通道分成多个组,每个组内的通道独立进行1x1卷积,从而显著降低了计算量。
在这里插入图片描述

然而,这种方法可能导致通道间的信息无法充分交流,影响模型的表达能力。可能会降低网络的特征提取能力

通道洗牌(Channel Shuffle):

为了解决逐点分组卷积带来的通道间信息交流不足的问题,ShuffleNet引入了通道洗牌操作。通过均匀地打乱不同分组中的通道,使得每个分组都能获得来自其他分组的信息,从而增强模型的特征提取能力。

在这里插入图片描述
在这里插入图片描述

ShuffleNet对ResNet中的Bottleneck结构进行由(a)到(b), ©的更改:

  1. 将开始和最后的 1×1卷积模块(降维、升维)改成Point Wise Group Convolution

  2. 为了进行不同通道的信息交流,再降维之后进行Channel Shuffle

  3. 降采样模块中, 3×3 Depth Wise Convolution的步长设置为2,长宽降为原来的一般,因此shortcut中采用步长为23×3平均池化,并把相加改成拼接。
    在这里插入图片描述
    ShuffleV1Block

class ShuffleV1Block(nn.Cell):def __init__(self, inp, oup, group, first_group, mid_channels, ksize, stride):super(ShuffleV1Block, self).__init__()self.stride = stridepad = ksize // 2self.group = groupif stride == 2:outputs = oup - inpelse:outputs = oupself.relu = nn.ReLU()branch_main_1 = [GroupConv(in_channels=inp, out_channels=mid_channels,kernel_size=1, stride=1, pad_mode="pad", pad=0,groups=1 if first_group else group),nn.BatchNorm2d(mid_channels),nn.ReLU(),]branch_main_2 = [nn.Conv2d(mid_channels, mid_channels, kernel_size=ksize, stride=stride,pad_mode='pad', padding=pad, group=mid_channels,weight_init='xavier_uniform', has_bias=False),nn.BatchNorm2d(mid_channels),GroupConv(in_channels=mid_channels, out_channels=outputs,kernel_size=1, stride=1, pad_mode="pad", pad=0,groups=group),nn.BatchNorm2d(outputs),]self.branch_main_1 = nn.SequentialCell(branch_main_1)self.branch_main_2 = nn.SequentialCell(branch_main_2)if stride == 2:self.branch_proj = nn.AvgPool2d(kernel_size=3, stride=2, pad_mode='same')def construct(self, old_x):left = old_xright = old_xout = old_xright = self.branch_main_1(right)if self.group > 1:right = self.channel_shuffle(right)right = self.branch_main_2(right)if self.stride == 1:out = self.relu(left + right)elif self.stride == 2:left = self.branch_proj(left)out = ops.cat((left, right), 1)out = self.relu(out)return outdef channel_shuffle(self, x):batchsize, num_channels, height, width = ops.shape(x)group_channels = num_channels // self.groupx = ops.reshape(x, (batchsize, group_channels, self.group, height, width))x = ops.transpose(x, (0, 2, 1, 3, 4))x = ops.reshape(x, (batchsize, num_channels, height, width))return x

ShuffleNet的基本单元是在残差单元(residual block)的基础上改进而成的,具体结构如下:

1x1分组卷积:首先,输入特征图通过一个1x1的分组卷积进行降维,减少通道数。
通道洗牌:紧接着,对分组卷积的输出进行通道洗牌操作,以实现不同分组之间的信息交流。
3x3深度可分离卷积:然后,使用3x3的深度可分离卷积(depthwise separable convolution)进行特征提取。这里的3x3卷积是瓶颈层(bottleneck),用于降低计算量。
1x1分组卷积(可选):最后,根据需要,可以通过另一个1x1的分组卷积将通道数恢复到与输入相同或更大的数量。
短路连接:在基本单元中,还包含短路连接(shortcut),用于将输入特征图直接加到输出特征图上,以保留原始信息并帮助梯度回传。
在这里插入图片描述

ShuffleNet网络结构如上图所示,以输入图像 224×224 ,组数3(g = 3)为例,首先通过数量24,卷积核大小为 3×3stride2的卷积层,输出特征图大小为 112×112 ,channel为24;然后通过stride为2的最大池化层,输出特征图大小为 56×56channel数不变;再堆叠3个ShuffleNet模块(Stage2, Stage3, Stage4),三个模块分别重复4次、8次、4次,其中每个模块开始先经过一次下采样模块(上图©),使特征图长宽减半,channel翻倍(Stage2的下采样模块除外,将channel数从24变为240);随后经过全局平均池化,输出大小为 1×1×960 ,再经过全连接层softmax,得到分类概率

ShuffleNetV1

class ShuffleNetV1(nn.Cell):def __init__(self, n_class=1000, model_size='2.0x', group=3):super(ShuffleNetV1, self).__init__()print('model size is ', model_size)self.stage_repeats = [4, 8, 4]self.model_size = model_sizeif group == 3:if model_size == '0.5x':self.stage_out_channels = [-1, 12, 120, 240, 480]elif model_size == '1.0x':self.stage_out_channels = [-1, 24, 240, 480, 960]elif model_size == '1.5x':self.stage_out_channels = [-1, 24, 360, 720, 1440]elif model_size == '2.0x':self.stage_out_channels = [-1, 48, 480, 960, 1920]else:raise NotImplementedErrorelif group == 8:if model_size == '0.5x':self.stage_out_channels = [-1, 16, 192, 384, 768]elif model_size == '1.0x':self.stage_out_channels = [-1, 24, 384, 768, 1536]elif model_size == '1.5x':self.stage_out_channels = [-1, 24, 576, 1152, 2304]elif model_size == '2.0x':self.stage_out_channels = [-1, 48, 768, 1536, 3072]else:raise NotImplementedErrorinput_channel = self.stage_out_channels[1]self.first_conv = nn.SequentialCell(nn.Conv2d(3, input_channel, 3, 2, 'pad', 1, weight_init='xavier_uniform', has_bias=False),nn.BatchNorm2d(input_channel),nn.ReLU(),)self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')features = []for idxstage in range(len(self.stage_repeats)):numrepeat = self.stage_repeats[idxstage]output_channel = self.stage_out_channels[idxstage + 2]for i in range(numrepeat):stride = 2 if i == 0 else 1first_group = idxstage == 0 and i == 0features.append(ShuffleV1Block(input_channel, output_channel,group=group, first_group=first_group,mid_channels=output_channel // 4, ksize=3, stride=stride))input_channel = output_channelself.features = nn.SequentialCell(features)self.globalpool = nn.AvgPool2d(7)self.classifier = nn.Dense(self.stage_out_channels[-1], n_class)def construct(self, x):x = self.first_conv(x)x = self.maxpool(x)x = self.features(x)x = self.globalpool(x)x = ops.reshape(x, (-1, self.stage_out_channels[-1]))x = self.classifier(x)return x

设置model_size="2.0x",定义模型的复杂度。

 net = ShuffleNetV1(model_size="2.0x", n_class=10)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://xiahunao.cn/news/3223918.html

如若内容造成侵权/违法违规/事实不符,请联系瞎胡闹网进行投诉反馈,一经查实,立即删除!

相关文章

python reload找不到怎么办

Python 3.0 把 reload 内置函数移到了 imp 标准库模块中。它仍然像以前一样重载文件,但是,必须导入它才能使用。 方法一: from imp import reload reload(module) 方法二: import imp imp.reload(module)

4:表单和通用视图

表单和通用视图 1、编写一个简单的表单&#xff08;1&#xff09;更新polls/detail.html文件 使其包含一个html < form > 元素&#xff08;2&#xff09;创建一个Django视图来处理提交的数据&#xff08;3&#xff09;当有人对 Question 进行投票后&#xff0c;vote()视图…

入门PHP就来我这(高级)19 ~ 捕获sql错误

有胆量你就来跟着路老师卷起来&#xff01; -- 纯干货&#xff0c;技术知识分享 路老师给大家分享PHP语言的知识了&#xff0c;旨在想让大家入门PHP&#xff0c;并深入了解PHP语言。 接着上篇我们来看下sql错误的捕获模式。 1 PDO中捕获SQL语句中的错误 在PDO中有3种方法可以捕…

企业化运维(7)_Zabbix企业级监控平台

官网&#xff1a;Zabbix :: The Enterprise-Class Open Source Network Monitoring Solution ###1.Zabbix部署### &#xff08;1&#xff09;zabbix安装 安装源 修改安装路径为清华镜像 [rootserver1 zabbix]# cd /etc/yum.repos.d/ [rootserver1 yum.repos.d]# vim zabbix.r…

嵌入式c语言——指针加修饰符

指针变量可以用修饰符来修饰

软件开发面试题(C#语言,.NET框架)

1. 解释什么是委托&#xff08;Delegate&#xff09;&#xff0c;并举例说明它在C#中的用法。 委托是一种引用类型&#xff0c;它可以用于封装一个或多个方法。委托对象可以像方法一样调用&#xff0c;甚至可以用于创建事件处理程序。委托是C#中实现事件和回调函数的重要机制。…

Proteus + Keil单片机仿真教程(五)多位LED数码管的静态显示

Proteus + Keil单片机仿真教程(五)多位LED数码管 上一章节讲解了单个数码管的静态和动态显示,这一章节将对多个数码管的静态显示进行学习,本章节主要难点: 1.锁存器的理解和使用; 2.多个数码管的接线封装方式; 3.Proteus 快速接头的使用。 第一个多位数码管示例 元件…

windows JDK11 与JDK1.8自动切换,以及切换后失效的问题

1.windows安装不同环境的jdk 2.切换jdk 3.切换失败 原因&#xff1a;这是因为当我们安装并配置好JDK11之后它会自动生成一个环境变量&#xff08;此变量我们看不到&#xff09;&#xff0c;此环境变量优先级较高&#xff0c;导致我们在切换回JDK8后系统会先读取到JDK11生成的…

vue3中使用provide跨层传值(方法)

1.使用provide inject 跨层实现 祖父组件&#xff1a; provide有两个参数&#xff0c;第一个是我们传递的key&#xff0c;第二个就是value了 孙子组件&#xff1a; const dataList inject(getDataList1)//使用inject接收 const dataList1 dataList.getDataList 页面中使…

Nuxt框架中内置组件详解及使用指南(四)

title: Nuxt框架中内置组件详解及使用指南&#xff08;四&#xff09; date: 2024/7/9 updated: 2024/7/9 author: cmdragon excerpt: 摘要&#xff1a;本文详细介绍了Nuxt 3框架中的两个内置组件&#xff1a;和的使用方法与示例。用于捕获并处理客户端错误&#xff0c;提供…

jvm 06 对象内存结构,指针压缩,调优

01 内存布局 mark word 32bit 4B 64bit 8B 类型指针 klass pointer 开启指针压缩 4B 关闭指针压缩 8B 数组长度 4B 没有这个区域 实例数据 bool 1B 1 true&#xff0c;0 false #define TRUE 1 byte 1B char 2B 1B int 4B float 4B long 8B double 8B 引用类型 开启指针压缩 4B …

基于STM32的气压检测报警proteus仿真设计(仿真+程序+设计报告+讲解视频)

基于STM32的气压检测报警proteus仿真设计 1.主要功能2.仿真3. 程序4. 设计报告5. 资料清单&下载链接资料下载链接&#xff1a; 基于STM32的气压检测报警proteus仿真设计(仿真程序设计报告讲解视频&#xff09; 仿真图proteus 8.9 程序编译器&#xff1a;keil 5 编程语言…

Python实现吃豆人游戏详解(内附完整代码)

一、吃豆人游戏背景 吃豆人是一款由Namco公司在1980年推出的经典街机游戏。游戏的主角是一个黄色的小圆点&#xff0c;它必须在迷宫中吃掉所有的点数&#xff0c;同时避免被四处游荡的幽灵捉到。如果玩家能够吃掉所有的点数&#xff0c;并且成功避开幽灵&#xff0c;就可以进入…

【多线程】wait()和notify()

&#x1f970;&#x1f970;&#x1f970;来都来了&#xff0c;不妨点个关注叭&#xff01; &#x1f449;博客主页&#xff1a;欢迎各位大佬!&#x1f448; 文章目录 1. 为什么需要wait()方法和notify()方法&#xff1f;2. wait()方法2.1 wait()方法的作用2.2 wait()做的事情2…

一文带你彻底搞懂什么是责任链模式!!

文章目录 什么是责任链模式&#xff1f;详细示例SpingMVC 中的责任链模式使用总结 什么是责任链模式&#xff1f; 在我们日常生活中&#xff0c;经常会出现一种场景&#xff1a;一个请求需要经过多个对象的处理才能得到最终的结果。比如&#xff0c;一个请假申请&#xff0c;需…

保姆级教程:Linux (Ubuntu) 部署流光卡片开源 API

流光卡片 API 开源地址 Github&#xff1a;https://github.com/ygh3279799773/streamer-card 流光卡片 API 开源地址 Gitee&#xff1a;https://gitee.com/y-gh/streamer-card 流光卡片在线使用地址&#xff1a;https://fireflycard.shushiai.com/ 等等&#xff0c;你说你不…

如何在Excel中对一个或多个条件求和?

在Excel中&#xff0c;基于一个或多个条件的求和值是我们大多数人的常见任务&#xff0c;SUMIF函数可以帮助我们根据一个条件快速求和&#xff0c;而SUMIFS函数可以帮助我们对多个条件求和。 本文&#xff0c;我将描述如何在Excel中对一个或多个条件求和&#xff1f; 在Excel中…

2020 ICPC Shanghai Site B. Mine Sweeper II 题解 构造 鸽巢原理

Mine Sweeper II 题目描述 A mine-sweeper map X X X can be expressed as an n m n\times m nm grid. Each cell of the grid is either a mine cell or a non-mine cell. A mine cell has no number on it. Each non-mine cell has a number representing the number of…

gif压缩大小但不改变画质的最佳方法,7个gif压缩免费工具别错过!

你会不会也碰到过当你需要在自媒体平台上上传gif文件时&#xff0c;你会发现网页端最大限制为15MB&#xff0c;而手机端最大限制为5MB。那么如何在不不改变画质的同时压缩gif大小呢&#xff1f;如今&#xff0c;由于其特殊的动画以及快速传输的特点&#xff0c;gif文件已经成为…

原创作品—数据可视化大屏

设计数据可视化大屏时&#xff0c;用户体验方面需注重以下几点&#xff1a;首先&#xff0c;确保大屏信息层次分明&#xff0c;主要数据突出显示&#xff0c;次要信息适当弱化&#xff0c;帮助用户快速捕捉关键信息。其次&#xff0c;设计应直观易懂&#xff0c;避免复杂难懂的…