模型剪枝中有哪些经验|mobile-yolov5-pruning-distillation项目中剪枝知识分析

项目地址:https://github.com/Syencil/mobile-yolov5-pruning-distillation
项目时间:2022年
mobile-yolov5-pruning-distillation是一个以yolov5改进为主的开源项目,主要包含3中改进方向:更改backbone、模型剪枝、知识蒸馏。这里主要研究其模型剪枝部分,关于知识蒸馏后续在进行分析。关于更改非coco训出的backbone(使用moblienet替换),可以发现存在相助的精度下降,这表明imagenet域训练处的权重迁移到目标检测领域不如二次迁移的模型(先imagenet,再coco训练)。
从项目代码分析中,主要学习到bn层的稀疏化训练是如何实现的(章节2.1)、如果按照权重大小对模型进行剪枝(章节2.3)
在这里插入图片描述

yolov5s在640x640分辨率下的计算量和参数量分别为8.39G和7.07M。在速度上仍然有提升空间,通过替换backbone(mobilenetv2),通道剪枝对模型进行压缩。 利用蒸馏提升小模型的精度。项目以工程化为基础,主要是模型端的优化。实现了常用的剪枝和蒸馏算法,并对其做了一个简单的介绍和评估。通过在voc上实验来验证各个方法的有效性。 最后将工程可用模型转换成对应部署版本。

1、剪枝操作

1.1 baseline

数据集采用Pascal VOC,trainset = train2007+train2012+val2007+val2012,testset = test2007,Baseline采用mobile-yolo(imagenet预训练),如果没有特别说明,第一个模块采用Focus 如果未经特殊说明则均为使用默认参数,batchsize=24,epoch=50,train_size = 640,test_size = 640,conf_thres=0.001,iou_thres=0.6,mAP均为50

PS. 由于资源有限,此项目只训练50个epoch,实际上可以通过调整学习率和迭代次数进一步提高mAP。但是可以通过控制相同的超参数来进行实验对比,所以并不影响最终结果。

在这里插入图片描述

baseline由4个部分组成:yolov5s,官方提供的coco权重在voc上进行微调所以不具备可比性,但是可以作为蒸馏指导模型;mobilev2-yolo5s和mobilev2-yolo5l均是只更改了对应的backbone;mobilev2-yolo3则是用的yolo3head,结构同keras-YOLOv3-mobilenet 基本一致(keras的是mobilev1,参数量和计算量更大),此处作为参照物。

1.2 模型剪枝

选取mobilev2-yolo5s作为剪枝的基础模型。以以下策略为基础:
1、输出层不动,统计其他所有BN层的weight分布
2、根据稀疏率决定剪枝阈值
3、开始剪枝,如果当前层所有值均小于阈值则保留最大的一个通道(保证结构不被破坏)
在这里插入图片描述
先从头训练一个baseline,以及训练一个对bn中gamma参数加入L1正则化的网络。稀疏参数为sl=6e-4。结果比baseline掉了3个点。

剪枝策略按照论文中的做法给定一个稀疏率,统计所有参与剪枝层的bn参数l1值并进行排序,依据稀疏率确定阈值。

将所有小于阈值的层全部减掉,如果有依赖则将依赖的对应部分也剪掉。如果一层中所有的层都需要被移除,那么就保留最大的一层通道(保证网络结构)。

根据以下图表数据,可以发现剪枝了约25%的参数(主要看pruning2),模型预计加速1.25x,map下降了0.2,可能是训练不充分导致的。
在这里插入图片描述

1.3 与知识蒸馏的对比

作者在后面也分享了知识蒸馏的结果,可以发现蒸馏后的mobilenet比原始的模型高2哥百分点,map达74%,相比于对大模型进行剪枝,或许知识蒸馏更能保持精度。
对蒸馏后的模型进行剪枝或许能再度提升模型的运行速度。
在这里插入图片描述

知识蒸馏信息如下:

将特征图和输出层一起作为蒸馏指导。对于T和S中间特征图输出维度不匹配的问题,采用在S网络输出接一个Converter,将其升维到T网络匹配。 Converter由conv+bn+relu6组成,T网络输出单独接一个relu6,保证激活函数相同。(上个commit版本出现了一个bug,导致精度没变其实是不对的,现已修正) output层参数为1.0,feature参数为0.5。mAP0.663甚至比baseline都要低。feature distillation居然让模型掉点了,怀疑是feature权重太大,降到0.1667,mAP可以提升到0.68,还是低于baseline。 继续下降到0.05,mAP可以回到baseline的水平,不过在训练末期mAP还在上升,loss还在下降。最后尝试训练100个epoch,mAP才回到74。 实际上还尝试过各种变形和各种参数,但是感觉效果仍然不好。

2、其中关键代码

2.1 bn层稀疏化训练

bn层稀疏化训练的原理可以查看 https://blog.csdn.net/a486259/article/details/140594912 中的1.1 BN层剪枝的理解基础。

这里主要分析如何实现稀疏化训练,相关代码在train.py中,相关关键代码为sl。进行检索,发现多处关键代码。
关键一: bn稀疏化设置
通过以下代码,为bn gamma值稀疏化,设置稀疏率sl与weight_decay(l2正则化)值相同的缩放比例。同时,将bn参数单独加入到pg0组中,设置单独的学习率;将非bn参数加入到pg1中,设置单独的权重衰减(l2正则化);将所有的bias参数都不进行正则化。

    hyp['weight_decay'] *= batch_size * accumulate / nbs  # scale weight_decayif opt.sl > 0:hyp['sl'] *= batch_size * accumulate / nbspg0, pg1, pg2 = [], [], []  # optimizer parameter groupsfor k, v in model.named_parameters():if v.requires_grad:if '.bias' in k:pg2.append(v)  # biaseselif '.weight' in k and '.bn' not in k:pg1.append(v)  # apply weight decayelse:pg0.append(v)  # all elseoptimizer = optim.Adam(pg0, lr=hyp['lr0']) if opt.adam else \optim.SGD(pg0, lr=hyp['lr0'], momentum=hyp['momentum'], nesterov=True)optimizer.add_param_group({'params': pg1, 'weight_decay': hyp['weight_decay']})  # add pg1 with weight_decayoptimizer.add_param_group({'params': pg2})  # add pg2 (biases)print('Optimizer groups: %g .bias, %g conv.weight, %g other' % (len(pg2), len(pg1), len(pg0)))del pg0, pg1, pg2

关键二、稀疏化剪枝loss
对应函数为compute_pruning_loss
在这里插入图片描述
对函数深入分析可以发现,就是计算正则化loss
在这里插入图片描述

以上代码比较繁琐,但其主要目的是对非bn层参数进行l2正则化,对bn层参数进行l1正则化。l2正则化使得conv、linear层参数趋近于0,l1正则化使大量的bn 层gamma参数值为0。这样使得稀疏化训练后的模型剪枝后精度下降很低。这样的约束会导致模型精度存在轻微下降,但在剪枝后重新训练模型移除约束后,基本上会使精度恢复到约束前。

2.2 可剪枝层获取

通过以下代码获取了模型中所有的bn层

    if opt.sl > 0:print("Sparse Learning Model!")print("===> Sparse learning rate is ", opt.sl)ignore_idx = [230, 260, 290]prunable_modules = []prunable_module_type = (nn.BatchNorm2d, )for i, m in enumerate(model.modules()):if i in ignore_idx:continueif isinstance(m, prunable_module_type):prunable_modules.append(m)

2.3 基于bn层的剪枝

相关代码在pruning.py中,主要是channel_prune。
1、先基于bn_analyze对bn层的gamma参数进行统计(一个bn层中有多个gamma参数,bn输出64个channel,就有64个gamma参数);
2、然后遍历可剪枝模块(所有的bn层),根据阈值获取每一个bn层中需剪枝的index
3、基于DG.get_pruning_plan直接对模块中指定index进行剪枝

        pos = np.array([i for i in range(len(L1_norm))])pruned_idx_mask = L1_norm < thresprun_index = pos[pruned_idx_mask].tolist()if len(prun_index) == len(L1_norm):del prun_index[np.argmax(L1_norm)]plan = DG.get_pruning_plan(layer_to_prune, prune_fn, prun_index)plan.exec()

bn_analyze函数统计参数过程中绘制的直方图如下
在这里插入图片描述

完整代码如下:

def bn_analyze(prunable_modules, save_path=None):bn_val = []max_val = []for layer_to_prune in prunable_modules:# select a layerweight = layer_to_prune.weight.data.detach().cpu().numpy()max_val.append(max(weight))bn_val.extend(weight)bn_val = np.abs(bn_val)max_val = np.abs(max_val)bn_val = sorted(bn_val)max_val = sorted(max_val)plt.hist(bn_val, bins=101, align="mid", log=True, range=(0, 1.0))if save_path is not None:if os.path.isfile(save_path):os.remove(save_path)plt.savefig(save_path)return bn_val, max_valdef channel_prune(ori_model, example_inputs, output_transform, pruned_prob=0.3, thres=None):model = copy.deepcopy(ori_model)model.cpu().eval()prunable_module_type = (nn.BatchNorm2d)ignore_idx = [230, 260, 290]prunable_modules = []for i, m in enumerate(model.modules()):if i in ignore_idx:continueif isinstance(m, prunable_module_type):prunable_modules.append(m)ori_size = tp.utils.count_params(model)DG = tp.DependencyGraph().build_dependency(model, example_inputs=example_inputs,output_transform=output_transform)bn_val, max_val = bn_analyze(prunable_modules, "render_img/before_pruning.jpg")if thres is None:thres_pos = int(pruned_prob * len(bn_val))thres_pos = min(thres_pos, len(bn_val)-1)thres_pos = max(thres_pos, 0)thres = bn_val[thres_pos]print("Min val is %f, Max val is %f, Thres is %f" % (bn_val[0], bn_val[-1], thres))for layer_to_prune in prunable_modules:# select a layerweight = layer_to_prune.weight.data.detach().cpu().numpy()if isinstance(layer_to_prune, nn.Conv2d):if layer_to_prune.groups > 1:prune_fn = tp.prune_group_convelse:prune_fn = tp.prune_convL1_norm = np.sum(np.abs(weight), axis=(1, 2, 3))elif isinstance(layer_to_prune, nn.BatchNorm2d):prune_fn = tp.prune_batchnormL1_norm = np.abs(weight)pos = np.array([i for i in range(len(L1_norm))])pruned_idx_mask = L1_norm < thresprun_index = pos[pruned_idx_mask].tolist()if len(prun_index) == len(L1_norm):del prun_index[np.argmax(L1_norm)]plan = DG.get_pruning_plan(layer_to_prune, prune_fn, prun_index)plan.exec()bn_analyze(prunable_modules, "render_img/after_pruning.jpg")with torch.no_grad():out = model(example_inputs)if output_transform:out = output_transform(out)print("  Params: %s => %s" % (ori_size, tp.utils.count_params(model)))if isinstance(out, (list, tuple)):for o in out:print("  Output: ", o.shape)else:print("  Output: ", out.shape)print("------------------------------------------------------\n")return model

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://xiahunao.cn/news/3266082.html

如若内容造成侵权/违法违规/事实不符,请联系瞎胡闹网进行投诉反馈,一经查实,立即删除!

相关文章

路由表与IP数据报的转发

前言&#xff1a;本博客仅作记录学习使用&#xff0c;部分图片出自网络&#xff0c;如有侵犯您的权益&#xff0c;请联系删除 一、相关知识 1、路由类型 路由表中有3类路由&#xff1a;直连路由、静态路由、动态路由 直连路由&#xff1a;一般指去往路由器接口直接连接网络的…

JAW:一款针对客户端JavaScript的图形化安全分析框架

关于JAW JAW是一款针对客户端JavaScript的图形化安全分析框架&#xff0c;该工具基于esprima解析器和EsTree SpiderMonkey Spec实现其功能&#xff0c;广大研究人员可以使用该工具分析Web应用程序和基于JavaScript的客户端程序的安全性。 工具特性 1、动态可扩展的框架&#x…

LeetCode 2844.生成特殊数字的最少操作(哈希表 + 贪心)

给你一个下标从 0 开始的字符串 num &#xff0c;表示一个非负整数。 在一次操作中&#xff0c;您可以选择 num 的任意一位数字并将其删除。请注意&#xff0c;如果你删除 num 中的所有数字&#xff0c;则 num 变为 0。 返回最少需要多少次操作可以使 num 变成特殊数字。 如…

vue接入google map自定义marker教程

需求背景 由于客户需求&#xff0c;原来系统接入的高德地图&#xff0c;他们不接受&#xff0c;需要换成google地图。然后就各种百度&#xff0c;各种Google&#xff0c;却不能实现。----无语&#xff0c;就连google地图官方的api也是一坨S-H-I。所以才出现这篇文章。 google地…

CSS(七)——CSS 列表和CSS Table(表格)

目录 CSS 列表 列表 作为列表项标记的图像 列表 - 简写属性 移除默认设置 所有的CSS列表属性 CSS 表格 表格边框 折叠边框&#xff08;border-collapse&#xff09; 表格宽度和高度 表格文字对齐 表格填充 表格颜色 CSS 列表 CSS 列表属性作用如下&#xff1a; 设…

Hello SLAM(在Linux中实现第一个C++程序)

首先需要安装vim编辑器&#xff0c;输入命令 sudo apt install vim 在Ubuntu上安装好vim编辑器后&#xff0c;创建路径&#xff08;/home/slambook/ch2&#xff09;&#xff0c;在该路径下创建一个cpp文档&#xff08;touch hello.c&#xff09;&#xff0c;通过vim编辑器进行…

【OpenCV C++20 学习笔记】图片处理基础

OpenCV C20 图片处理基础 VS 2022 C20 标准库导入的问题头文件包含以及命名空间声明main函数读取图片读取检查显式图片写入图片 完整代码bug VS 2022 C20 标准库导入的问题 VS还没有完全兼容C20。C20的import语句不一定能正确导入标准库&#xff0c;所以必须要新建一个头文件专…

【全面介绍Python多线程】

🎥博主:程序员不想YY啊 💫CSDN优质创作者,CSDN实力新星,CSDN博客专家 🤗点赞🎈收藏⭐再看💫养成习惯 ✨希望本文对您有所裨益,如有不足之处,欢迎在评论区提出指正,让我们共同学习、交流进步! 🦇目录 1. 🦇前言2. 🦇threading 模块的基本用法3. 🦇Thre…

.NET 相关概念

.NET 和 .NET SDK .NET 介绍 .NET 是一个由 Microsoft 开发和维护的广泛用于构建各种类型应用程序的开发框架。它是一个跨平台、跨语言的开发平台&#xff0c;提供了丰富的类库、API和开发工具&#xff0c;支持开发者使用多种编程语言&#xff08;如C#、VB.NET、F#等&#xf…

[C++进阶]多态的概念、定义与实现

多态&#xff0c;顾名思义&#xff0c;即多种形态。具体来说&#xff0c;就是不同对象执行同一行为而产生不同的结果。 一、多态的概念 多态的概念&#xff1a;通俗来说&#xff0c;就是多种形态&#xff0c;具体点就是去完成某个行为&#xff0c;当不同的对象去完成时会产生…

机器学习 | 回归算法原理——随机梯度下降法

Hi&#xff0c;大家好&#xff0c;我是半亩花海。接着上次的多重回归继续更新《白话机器学习的数学》这本书的学习笔记&#xff0c;在此分享随机梯度下降法这一回归算法原理。本章的回归算法原理还是基于《基于广告费预测点击量》项目&#xff0c;欢迎大家交流学习&#xff01;…

NET8部署Kestrel服务HTTPS深入解读TLS协议之Certificate证书

Certificate证书 Certificate称为数字证书。数字证书是一种证明身份的电子凭证&#xff0c;它包含一个公钥和一些身份信息&#xff0c;用于验证数字签名和加密通信。数字证书在网络通信、电子签名、认证授权等场景中都有广泛应用。其特征如下&#xff1a; 由权威机构颁发&…

没有最好,只有适合:根据实际情况务实的设置软件研发环境

在容器化&#xff0c;开源软件和云服务驱动的软件开发时代&#xff0c;持续集成的理念已经深入人心&#xff0c;无论我们在哪一家公司&#xff0c;只要是开发和长期维护一款互联网产品&#xff0c;在从开发到上线的过程中&#xff0c;团队都会有一套研发环境&#xff0c;处于不…

【Docker】Windows11环境下的安装

前置依赖环境配置 确保虚拟化开启 搜索栏直接搜索如下功能 勾选下面两个选项&#xff0c;确定 重启电脑&#xff0c;以管理员身份打开PowerShell wsl --status wsl --update打开微软应用商店选择一个Ubuntu版本下载并打开 输入一个用户名和密码 然后就可以在Windows下使…

JavaWeb笔记_JSTL标签库JavaEE三层架构案例

一.JSTL标签库 1.1 JSTL概述 JSTL(jsp standard tag library):JSP标准标签库,它是针对EL表达式一个扩展,通过JSTL标签库与EL表达式结合可以完成更强大的功能 JSTL它是一种标签语言,JSTL不是JSP内置标签 JSTL标签库主要包含: ****核心标签 格式化标签 …

eqmx上读取数据处理以后添加到数据库中

目录 定义一些静态变量 定时器事件的处理器 订阅数据的执行器 处理json格式数据和将处理好的数据添加到数据库中 要求和最终效果 总结一下 定义一些静态变量 // 在这里都定义成全局的 一般都定义成静态的private static MqttClient mqttClient; // mqtt客户端 private s…

大模型llama结构技术点分享;transformer模型常见知识点nlp面经

1、大模型llama3技术点 参考&#xff1a;https://www.zhihu.com/question/662354435/answer/3572364267 Llama1-3&#xff0c;数据tokens从1-2T到15T;使用了MHA&#xff08;GQA缓存&#xff09;&#xff1b;上下文长度从2-4-8K&#xff1b;应用了强化学习对其。 1、pretraini…

WINUI——Microsoft.UI.Xaml.Markup.XamlParseException:“无法找到与此错误代码关联的文本。

开发环境 VS2022 .net core6 问题现象 在Canvas内的子控件要绑定Canvas的兄弟控件的一个属性&#xff0c;在运行时出现了下述报错。 可能原因 在 WinUI&#xff08;特别是用于 UWP 或 Windows App SDK 的版本&#xff09;中&#xff0c;如果你尝试在 XAML 中将 Canvas 内的…

IEC104转MQTT网关轻松将IEC104设备数据传输到Zabbix、阿里云、华为云、亚马逊AWS、ThingsBoard、Ignition云平台

随着工业4.0的深入发展和物联网技术的广泛应用&#xff0c;IEC 104&#xff08;IEC 60870-5-104&#xff09;作为电力系统中的重要通信协议&#xff0c;正逐步与各种现代监控、管理和云平台实现深度融合。IEC104转MQTT网关BE113作为这一融合过程中的关键设备&#xff0c;其能够…