1.distiller剪枝模块的使用
(1)distiller自带剪枝实例测试
distiller自带一些测试实例如ResNet56+cifar-10,下面是对ResNet56+cifar-10的测试:
-
测试前准备
- yaml文件(注意:这里的yaml文件是coder配置好的,具体到自己的模型需要先对自己的model进行一次Sparsity Analysis,然后自己配置该文件) 在剪枝时所用到的yaml文件作用主要是配置了一些剪枝所需要的必要信息,比如下面ResNet56所需要用的yaml配置文件(路径:distiller/examples/pruning_filters_for_efficient_convnets/resnet56_cifar_filter_rank.yaml):
version: 1 # 版本 pruners: filter_pruner_60: # 后面的60表示剪掉60%的Filters,如[16, 16, 3, 3]剪掉之后就是[7, 16, 3, 3]class: 'L1RankedStructureParameterPruner' # 表示所使用的算法,这里使用L1Rankgroup_type: Filters # 表示剪切类型,一般两种Filters/Channeldesired_sparsity: 0.6 # 剪掉60%的Filtersweights: [ # 下面是一些具体的需要剪切的权值module.layer1.0.conv1.weight,module.layer1.1.conv1.weight,module.layer1.2.conv1.weight,module.layer1.3.conv1.weight,module.layer1.4.conv1.weight,module.layer1.5.conv1.weight,module.layer1.6.conv1.weight,module.layer1.7.conv1.weight,module.layer1.8.conv1.weight]filter_pruner_50: # 同上class: 'L1RankedStructureParameterPruner'group_type: Filtersdesired_sparsity: 0.5weights: [module.layer2.1.conv1.weight,module.layer2.2.conv1.weight,module.layer2.3.conv1.weight,module.layer2.4.conv1.weight,module.layer2.6.conv1.weight,module.layer2.7.conv1.weight]filter_pruner_10: # 同上class: 'L1RankedStructureParameterPruner'group_type: Filtersdesired_sparsity: 0.1weights: [module.layer3.1.conv1.weight]filter_pruner_30: # 同上class: 'L1RankedStructureParameterPruner'group_type: Filtersdesired_sparsity: 0.3weights: [module.layer3.2.conv1.weight,module.layer3.3.conv1.weight,module.layer3.5.conv1.weight,module.layer3.6.conv1.weight,module.layer3.7.conv1.weight,module.layer3.8.conv1.weight]extensions:net_thinner:class: 'FilterRemover'thinning_func_str: remove_filtersarch: 'resnet56_cifar' # 使用的网络dataset: 'cifar10' # 数据集lr_schedulers:exp_finetuning_lr:class: ExponentialLRgamma: 0.95policies:- pruner:instance_name: filter_pruner_60epochs: [0]- pruner:instance_name: filter_pruner_50epochs: [0]- pruner:instance_name: filter_pruner_30epochs: [0]- pruner:instance_name: filter_pruner_10epochs: [0]- extension:instance_name: net_thinnerepochs: [0]- lr_scheduler:instance_name: exp_finetuning_lrstarting_epoch: 10ending_epoch: 300frequency: 1
- 准备ResNet-56需要的模型文件,可下载:
https://s3-us-west-1.amazonaws.com/nndistiller/pruning_filters_for_efficient_convnets/checkpoint.resnet56_cifar_baseline.pth.tar -
剪枝
找到compress_classifier.py文件,如下:
$python3 compress_classifier.py -a=resnet56_cifar -p=50 ../../../data.cifar10 --epochs=70 --lr=0.1 --compress=../pruning_filters_for_efficient_convnets/resnet56_cifar_filter_rank.yaml --resume-from=checkpoint.resnet56_cifar_baseline.pth.tar --reset-optimizer --vs=0
参数解释: -a 表示模型名称(这里是工具自带的模型名称,其他的如resnet32_cifar, resnet44_cifar, resnet56_cifar等等 cifar的模型代码文件位于distiller/models/cifar10/resnet_cifar.py)
-p表示每隔多少打印一次
../../../data.cifar10
是数据集路径
--epochs 表示剪枝过后继续训练次数
--compress 表示所用的‘策略’(compress_scheduler),一般是yaml文件的路径
--resume-from 表示保存的模型的路径
--reset-optimizer 如果设置此参数,那么start_epoch=0,将optimizer重置为SGD, 学习绿设置为传入的学习率
--vs validation-split
具体的其他参数参看distiller/apputils/image_classifier.py文件和distiller/quantization/range_linear.py文件以及github上参数解释。
运行时会对模型进行剪枝,然后在测试集上测试,打印出top1和top5以及loss,运行结束后量化模型会保存在logs下。
(2)distiller对自己的模型剪枝
- 具体流程:1.Sparsity Analysis 分析各层weight的sensitivity,即对模型各个部分的稀疏性和可以pruning的程度有个了解。 2. Yaml file create 创建一个属于自己模型的配置文件,里面各层的稀疏度是由第一阶段分析出来的sensitivity而来的 3.Thinning 对网络进行真正的剪枝。
(3)代码
- 代码已全部整理在 https://github.com/BlossomingL/compression_tool
- distiller修改代码在 https://github.com/BlossomingL/Distiller
- 有用的话麻烦手动点一下小星星,谢谢!