项目地址:https://github.com/Syencil/mobile-yolov5-pruning-distillation
项目时间:2022年
mobile-yolov5-pruning-distillation是一个以yolov5改进为主的开源项目,主要包含3中改进方向:更改backbone、模型剪枝、知识蒸馏。这里主要研究其模型剪枝部分,关于知识蒸馏后续在进行分析。关于更改非coco训出的backbone(使用moblienet替换),可以发现存在相助的精度下降,这表明imagenet域训练处的权重迁移到目标检测领域不如二次迁移的模型(先imagenet,再coco训练)。
从项目代码分析中,主要学习到bn层的稀疏化训练是如何实现的(章节2.1)、如果按照权重大小对模型进行剪枝(章节2.3)
yolov5s在640x640分辨率下的计算量和参数量分别为8.39G和7.07M。在速度上仍然有提升空间,通过替换backbone(mobilenetv2),通道剪枝对模型进行压缩。 利用蒸馏提升小模型的精度。项目以工程化为基础,主要是模型端的优化。实现了常用的剪枝和蒸馏算法,并对其做了一个简单的介绍和评估。通过在voc上实验来验证各个方法的有效性。 最后将工程可用模型转换成对应部署版本。
1、剪枝操作
1.1 baseline
数据集采用Pascal VOC,trainset = train2007+train2012+val2007+val2012,testset = test2007,Baseline采用mobile-yolo(imagenet预训练),如果没有特别说明,第一个模块采用Focus 如果未经特殊说明则均为使用默认参数,batchsize=24,epoch=50,train_size = 640,test_size = 640,conf_thres=0.001,iou_thres=0.6,mAP均为50
PS. 由于资源有限,此项目只训练50个epoch,实际上可以通过调整学习率和迭代次数进一步提高mAP。但是可以通过控制相同的超参数来进行实验对比,所以并不影响最终结果。
baseline由4个部分组成:yolov5s,官方提供的coco权重在voc上进行微调所以不具备可比性,但是可以作为蒸馏指导模型;mobilev2-yolo5s和mobilev2-yolo5l均是只更改了对应的backbone;mobilev2-yolo3则是用的yolo3head,结构同keras-YOLOv3-mobilenet 基本一致(keras的是mobilev1,参数量和计算量更大),此处作为参照物。
1.2 模型剪枝
选取mobilev2-yolo5s作为剪枝的基础模型。以以下策略为基础:
1、输出层不动,统计其他所有BN层的weight分布
2、根据稀疏率决定剪枝阈值
3、开始剪枝,如果当前层所有值均小于阈值则保留最大的一个通道(保证结构不被破坏)
先从头训练一个baseline,以及训练一个对bn中gamma参数加入L1正则化的网络。稀疏参数为sl=6e-4。结果比baseline掉了3个点。
剪枝策略按照论文中的做法给定一个稀疏率,统计所有参与剪枝层的bn参数l1值并进行排序,依据稀疏率确定阈值。
将所有小于阈值的层全部减掉,如果有依赖则将依赖的对应部分也剪掉。如果一层中所有的层都需要被移除,那么就保留最大的一层通道(保证网络结构)。
根据以下图表数据,可以发现剪枝了约25%的参数(主要看pruning2),模型预计加速1.25x,map下降了0.2,可能是训练不充分导致的。
1.3 与知识蒸馏的对比
作者在后面也分享了知识蒸馏的结果,可以发现蒸馏后的mobilenet比原始的模型高2哥百分点,map达74%,相比于对大模型进行剪枝,或许知识蒸馏更能保持精度。
对蒸馏后的模型进行剪枝或许能再度提升模型的运行速度。
知识蒸馏信息如下:
将特征图和输出层一起作为蒸馏指导。对于T和S中间特征图输出维度不匹配的问题,采用在S网络输出接一个Converter,将其升维到T网络匹配。 Converter由conv+bn+relu6组成,T网络输出单独接一个relu6,保证激活函数相同。(上个commit版本出现了一个bug,导致精度没变其实是不对的,现已修正) output层参数为1.0,feature参数为0.5。mAP0.663甚至比baseline都要低。feature distillation居然让模型掉点了,怀疑是feature权重太大,降到0.1667,mAP可以提升到0.68,还是低于baseline。 继续下降到0.05,mAP可以回到baseline的水平,不过在训练末期mAP还在上升,loss还在下降。最后尝试训练100个epoch,mAP才回到74。 实际上还尝试过各种变形和各种参数,但是感觉效果仍然不好。
2、其中关键代码
2.1 bn层稀疏化训练
bn层稀疏化训练的原理可以查看 https://blog.csdn.net/a486259/article/details/140594912 中的1.1 BN层剪枝的理解基础。
这里主要分析如何实现稀疏化训练,相关代码在train.py中,相关关键代码为sl。进行检索,发现多处关键代码。
关键一: bn稀疏化设置
通过以下代码,为bn gamma值稀疏化,设置稀疏率sl与weight_decay(l2正则化)值相同的缩放比例。同时,将bn参数单独加入到pg0组中,设置单独的学习率;将非bn参数加入到pg1中,设置单独的权重衰减(l2正则化);将所有的bias参数都不进行正则化。
hyp['weight_decay'] *= batch_size * accumulate / nbs # scale weight_decayif opt.sl > 0:hyp['sl'] *= batch_size * accumulate / nbspg0, pg1, pg2 = [], [], [] # optimizer parameter groupsfor k, v in model.named_parameters():if v.requires_grad:if '.bias' in k:pg2.append(v) # biaseselif '.weight' in k and '.bn' not in k:pg1.append(v) # apply weight decayelse:pg0.append(v) # all elseoptimizer = optim.Adam(pg0, lr=hyp['lr0']) if opt.adam else \optim.SGD(pg0, lr=hyp['lr0'], momentum=hyp['momentum'], nesterov=True)optimizer.add_param_group({'params': pg1, 'weight_decay': hyp['weight_decay']}) # add pg1 with weight_decayoptimizer.add_param_group({'params': pg2}) # add pg2 (biases)print('Optimizer groups: %g .bias, %g conv.weight, %g other' % (len(pg2), len(pg1), len(pg0)))del pg0, pg1, pg2
关键二、稀疏化剪枝loss
对应函数为compute_pruning_loss
对函数深入分析可以发现,就是计算正则化loss
以上代码比较繁琐,但其主要目的是对非bn层参数进行l2正则化,对bn层参数进行l1正则化。l2正则化使得conv、linear层参数趋近于0,l1正则化使大量的bn 层gamma参数值为0。这样使得稀疏化训练后的模型剪枝后精度下降很低。这样的约束会导致模型精度存在轻微下降,但在剪枝后重新训练模型移除约束后,基本上会使精度恢复到约束前。
2.2 可剪枝层获取
通过以下代码获取了模型中所有的bn层
if opt.sl > 0:print("Sparse Learning Model!")print("===> Sparse learning rate is ", opt.sl)ignore_idx = [230, 260, 290]prunable_modules = []prunable_module_type = (nn.BatchNorm2d, )for i, m in enumerate(model.modules()):if i in ignore_idx:continueif isinstance(m, prunable_module_type):prunable_modules.append(m)
2.3 基于bn层的剪枝
相关代码在pruning.py中,主要是channel_prune。
1、先基于bn_analyze对bn层的gamma参数进行统计(一个bn层中有多个gamma参数,bn输出64个channel,就有64个gamma参数);
2、然后遍历可剪枝模块(所有的bn层),根据阈值获取每一个bn层中需剪枝的index
3、基于DG.get_pruning_plan直接对模块中指定index进行剪枝
pos = np.array([i for i in range(len(L1_norm))])pruned_idx_mask = L1_norm < thresprun_index = pos[pruned_idx_mask].tolist()if len(prun_index) == len(L1_norm):del prun_index[np.argmax(L1_norm)]plan = DG.get_pruning_plan(layer_to_prune, prune_fn, prun_index)plan.exec()
bn_analyze函数统计参数过程中绘制的直方图如下
完整代码如下:
def bn_analyze(prunable_modules, save_path=None):bn_val = []max_val = []for layer_to_prune in prunable_modules:# select a layerweight = layer_to_prune.weight.data.detach().cpu().numpy()max_val.append(max(weight))bn_val.extend(weight)bn_val = np.abs(bn_val)max_val = np.abs(max_val)bn_val = sorted(bn_val)max_val = sorted(max_val)plt.hist(bn_val, bins=101, align="mid", log=True, range=(0, 1.0))if save_path is not None:if os.path.isfile(save_path):os.remove(save_path)plt.savefig(save_path)return bn_val, max_valdef channel_prune(ori_model, example_inputs, output_transform, pruned_prob=0.3, thres=None):model = copy.deepcopy(ori_model)model.cpu().eval()prunable_module_type = (nn.BatchNorm2d)ignore_idx = [230, 260, 290]prunable_modules = []for i, m in enumerate(model.modules()):if i in ignore_idx:continueif isinstance(m, prunable_module_type):prunable_modules.append(m)ori_size = tp.utils.count_params(model)DG = tp.DependencyGraph().build_dependency(model, example_inputs=example_inputs,output_transform=output_transform)bn_val, max_val = bn_analyze(prunable_modules, "render_img/before_pruning.jpg")if thres is None:thres_pos = int(pruned_prob * len(bn_val))thres_pos = min(thres_pos, len(bn_val)-1)thres_pos = max(thres_pos, 0)thres = bn_val[thres_pos]print("Min val is %f, Max val is %f, Thres is %f" % (bn_val[0], bn_val[-1], thres))for layer_to_prune in prunable_modules:# select a layerweight = layer_to_prune.weight.data.detach().cpu().numpy()if isinstance(layer_to_prune, nn.Conv2d):if layer_to_prune.groups > 1:prune_fn = tp.prune_group_convelse:prune_fn = tp.prune_convL1_norm = np.sum(np.abs(weight), axis=(1, 2, 3))elif isinstance(layer_to_prune, nn.BatchNorm2d):prune_fn = tp.prune_batchnormL1_norm = np.abs(weight)pos = np.array([i for i in range(len(L1_norm))])pruned_idx_mask = L1_norm < thresprun_index = pos[pruned_idx_mask].tolist()if len(prun_index) == len(L1_norm):del prun_index[np.argmax(L1_norm)]plan = DG.get_pruning_plan(layer_to_prune, prune_fn, prun_index)plan.exec()bn_analyze(prunable_modules, "render_img/after_pruning.jpg")with torch.no_grad():out = model(example_inputs)if output_transform:out = output_transform(out)print(" Params: %s => %s" % (ori_size, tp.utils.count_params(model)))if isinstance(out, (list, tuple)):for o in out:print(" Output: ", o.shape)else:print(" Output: ", out.shape)print("------------------------------------------------------\n")return model