Intel Distiller Toolkit - Quantization Implementation 1

Articles in this series

Intel Distiller Toolkit - Quantization Implementation 1

Intel Distiller Toolkit - Quantization Implementation 2


Distiller

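  • Distiller is a neural network compression toolkit developed by Intel around 2019. It supports methods including pruning, quantization, knowledge distillation, and low-rank decomposition.
  • This article explains how Distiller implements quantization. Since Distiller has barely been updated after 2019, it focuses on the classic quantization schemes, for learning purposes.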

Distiller quantization implementation

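  • First, I will use the GNMT quantization code from the Distiller examples to introduce Distiller's quantization framework; the code is shown in the figure below.
  • As the code shows, there are three steps:
    1. Collect statistics (calibration, done by QuantCalibrationStatsCollector)
    2. Create the quantizer (here the post-training quantizer PostTrainLinearQuantizer)
    3. Use the quantizer to quantize the model (prepare_model)
  • The rest of this article focuses on creating and applying the quantizer (steps 2 and 3 above).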
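The three-step flow above (calibrate, create the quantizer, prepare the model) can be mimicked with a small dependency-free sketch. `ToyCalibrationStatsCollector` and `ToyPostTrainQuantizer` are hypothetical stand-ins, not Distiller's API, and the sketch assumes per-layer min/max calibration with symmetric quantization, scale = (2^(b-1)-1) / max|x|.

```python
from typing import Dict, List, Tuple


class ToyCalibrationStatsCollector:
    """Step 1 (hypothetical stand-in): track the min/max activation seen per layer."""

    def __init__(self) -> None:
        self.stats: Dict[str, Tuple[float, float]] = {}

    def observe(self, layer_name: str, activations: List[float]) -> None:
        lo, hi = min(activations), max(activations)
        if layer_name in self.stats:
            old_lo, old_hi = self.stats[layer_name]
            lo, hi = min(lo, old_lo), max(hi, old_hi)
        self.stats[layer_name] = (lo, hi)


class ToyPostTrainQuantizer:
    """Steps 2-3 (hypothetical stand-in): build a quantizer from stats, then 'prepare' it."""

    def __init__(self, stats: Dict[str, Tuple[float, float]], num_bits: int = 8) -> None:
        self.stats = stats
        self.num_bits = num_bits
        self.scales: Dict[str, float] = {}

    def prepare_model(self) -> Dict[str, float]:
        # Symmetric range: map [-sat, sat] onto the integers [-(2^(b-1)-1), 2^(b-1)-1]
        q_max = 2 ** (self.num_bits - 1) - 1
        for name, (lo, hi) in self.stats.items():
            sat = max(abs(lo), abs(hi))
            self.scales[name] = q_max / sat if sat > 0 else 1.0
        return self.scales


# 1. Calibration: run a few batches and record activation ranges
collector = ToyCalibrationStatsCollector()
for batch in ([-1.0, 0.5], [0.2, 2.0]):
    collector.observe("fc1", batch)
# 2. Create the quantizer from the collected statistics
quantizer = ToyPostTrainQuantizer(collector.stats, num_bits=8)
# 3. Prepare (quantize) the model
scales = quantizer.prepare_model()
```

Here the two batches give fc1 the range [-1.0, 2.0], so the saturation value is 2.0 and the 8-bit scale is 127 / 2.0 = 63.5.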

Quantizer

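  • The core idea of Distiller's quantizers is module replacement: every module to be quantized, such as conv, linear, embedding, and parameter-free ops (e.g. add, multiply, concat), is wrapped in a wrapper. At inference time the wrapper quantizes the inputs and the weights (parameter-free modules have none), hands them to the wrapped real module (conv, linear, parameter-free op, etc.) for computation, and finally dequantizes the result if needed.
    • Note: Distiller wraps element-wise addition (elementwise_add), element-wise multiplication (elementwise_mul, i.e. the Hadamard product), matrix multiplication (matMul), batched matrix multiplication (BatchMatMul), and concat in nn.Module subclasses, so that all of them can be quantized uniformly through module replacement.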
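  • Now let's look at the definition of the Quantizer base class. The code below is annotated, but readers who have not read the complete Distiller source may still find parts of it hard to follow; interested readers are encouraged to read the source. It consists of the following parts:
    • Key member variables
      • Quantization bit-widths: the default setting for each module plus the external override settings (overrides), both stored as QBits objects (fields acts / wts / bias);
      • Replacement factory: replacement_factory, a dict mapping a module type to the function that builds its wrapper; it is populated by the concrete Quantizer subclass, covered in the next article;
      • Parameters to quantize: params_to_quantize, which records every parameter to be quantized together with related information (owning module, bit-width, etc.);
      • Parameter quantization function: param_quantization_fn, the function that quantizes parameters, set by the Quantizer subclass.
    • Processing flow (prepare_model)
      • Pre-processing (_pre_prepare_model, a no-op in the base class): e.g. BN folding (BN is used differently in training and inference), optimizing modules followed by activations, etc.;
      • Quantization replacement (_pre_process_container): replace each module to be quantized with its wrapper;
      • Post-processing (_post_prepare_model, also a no-op in the base class).
    • Source code: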
      import copy
      import logging
      import re
      import warnings
      from collections import OrderedDict, defaultdict, namedtuple
      from typing import Callable, Optional

      import torch.nn as nn
      import distiller

      msglogger = logging.getLogger()

      QBits = namedtuple('QBits', ['acts', 'wts', 'bias'])
      # _ParamToQuant, hack_float_backup_parameter and FP_BKP_PREFIX are helpers defined
      # in the same Distiller module; they are omitted here.


      class Quantizer(object):
          r"""
          Base class for quantizers.

          Args:
              model (torch.nn.Module): The model to be quantized
              optimizer (torch.optim.Optimizer): An optimizer instance, required in cases where the quantizer is going
                  to perform changes to existing model parameters and/or add new ones.
                  Specifically, when train_with_fp_copy is True, this cannot be None.
              bits_activations/weights/bias (int): Default number of bits to use when quantizing each tensor type.
                  Value of None means do not quantize.
              overrides (OrderedDict): Dictionary mapping regular expressions of layer name patterns to dictionary with
                  overrides of default values.
                  The keys in the overrides dictionary should be parameter names that the Quantizer accepts default values
                  for in its init function.
                  The parameters 'bits_activations', 'bits_weights', and 'bits_bias' which are accepted by the base Quantizer
                  are supported by default.
                  Other than those, each sub-class of Quantizer defines the set of parameter for which it supports
                  over-riding.
                  OrderedDict is used to enable handling of overlapping name patterns. So, for example, one could define
                  certain override parameters for a group of layers, e.g. 'conv*', but also define different parameters for
                  specific layers in that group, e.g. 'conv1'.
                  The patterns are evaluated eagerly - the first match wins. Therefore, the more specific patterns must
                  come before the broad patterns.
              train_with_fp_copy (bool): If true, will modify layers with weights to keep both a quantized and
                  floating-point copy, such that the following flow occurs in each training iteration:
                  1. q_weights = quantize(fp_weights)
                  2. Forward through network using q_weights
                  3. In back-prop:
                      3.1 Gradients calculated with respect to q_weights
                      3.2 We also back-prop through the 'quantize' operation from step 1
                  4. Update fp_weights with gradients calculated in step 3.2
          """
          def __init__(self, model, optimizer=None,
                       bits_activations=None, bits_weights=None, bits_bias=None,
                       overrides=None, train_with_fp_copy=False):
              if overrides is None:
                  overrides = OrderedDict()
              if not isinstance(overrides, OrderedDict):
                  raise TypeError('overrides must be an instance of collections.OrderedDict or None')

              if train_with_fp_copy and optimizer is None:
                  raise ValueError('optimizer cannot be None when train_with_fp_copy is True')

              # Relationships between graph nodes, used later for activation-related optimizations
              self.adjacency_map = None  # To be populated during prepare_model()

              # Default quantization bit-widths
              self.default_qbits = QBits(acts=bits_activations, wts=bits_weights, bias=bits_bias)
              self.overrides = overrides

              self.model = model
              self.optimizer = optimizer

              # Stash some quantizer data in the model so we can re-apply the quantizer on a resuming model
              self.model.quantizer_metadata = {'type': type(self),
                                               'params': {'bits_activations': bits_activations,
                                                          'bits_weights': bits_weights,
                                                          'bits_bias': bits_bias,
                                                          'overrides': copy.deepcopy(overrides)}}

              for k, v in self.overrides.items():
                  if any(old_bits_key in v.keys() for old_bits_key in ['acts', 'wts', 'bias']):
                      raise ValueError("Using 'acts' / 'wts' / 'bias' to specify bit-width overrides is deprecated.\n"
                                       "Please use the full parameter names: "
                                       "'bits_activations' / 'bits_weights' / 'bits_bias'")
                  qbits = QBits(acts=v.pop('bits_activations', self.default_qbits.acts),
                                wts=v.pop('bits_weights', self.default_qbits.wts),
                                bias=v.pop('bits_bias', self.default_qbits.bias))
                  v['bits'] = qbits

              # Prepare explicit mapping from each layer to QBits based on default + overrides
              patterns = []
              regex_overrides = None
              # Some modules override the default quantization settings
              if overrides:
                  patterns = list(overrides.keys())
                  regex_overrides_str = '|'.join(['(^{0}$)'.format(pattern) for pattern in patterns])
                  regex_overrides = re.compile(regex_overrides_str)

              self.module_qbits_map = {}
              self.module_overrides_map = {}
              # Set the quantization bit-widths of each module
              for module_full_name, module in model.named_modules():
                  # Need to account for scenario where model is parallelized with DataParallel, which wraps the original
                  # module with a wrapper module called 'module' :)
                  name_to_match = module_full_name.replace('module.', '', 1)
                  qbits = self.default_qbits
                  override_entry = self.overrides.get(name_to_match, OrderedDict())
                  if regex_overrides:
                      m_overrides = regex_overrides.match(name_to_match)
                      if m_overrides:
                          group_idx = 0
                          groups = m_overrides.groups()
                          while groups[group_idx] is None:
                              group_idx += 1
                          override_entry = copy.deepcopy(override_entry or self.overrides[patterns[group_idx]])
                          qbits = override_entry.pop('bits', self.default_qbits)

                  self._add_qbits_entry(module_full_name, type(module), qbits)
                  self._add_override_entry(module_full_name, override_entry)

              # Mapping from module type to function generating a replacement module suited for quantization
              # To be populated by child classes
              # Unspecified layer types return None by default.
              self.replacement_factory = defaultdict(lambda: None)
              # Pointer to parameters quantization function, triggered during training process
              # To be populated by child classes
              self.param_quantization_fn = None  # parameter quantization function

              self.train_with_fp_copy = train_with_fp_copy
              self.params_to_quantize = []

              # A dictionary of replaced modules and their respective names.
              self.modules_processed = OrderedDict()  # modules already processed

          def _add_qbits_entry(self, module_name, module_type, qbits):
              if module_type not in [nn.Conv2d, nn.Conv3d, nn.Linear, nn.Embedding]:
                  # For now we support weights quantization only for Conv, FC and Embedding layers (so, for example, we don't
                  # support quantization of batch norm scale parameters)
                  qbits = QBits(acts=qbits.acts, wts=None, bias=None)
              self.module_qbits_map[module_name] = qbits

          def _add_override_entry(self, module_name, entry):
              self.module_overrides_map[module_name] = entry

          def prepare_model(self, dummy_input=None):
              """
              Traverses the model and replaces sub-modules with quantized counterparts according to the bit-width
              and overrides configuration provided to __init__(), and according to the replacement_factory as
              defined by the Quantizer sub-class being used.

              Note:
                  If multiple sub-modules within the model actually reference the same module, then that module
                  is replaced only once, according to the configuration (bit-width and/or overrides) of the
                  first encountered reference.
                  Toy Example - say a module is constructed using this bit of code:

                      shared_relu = nn.ReLU
                      self.relu1 = shared_relu
                      self.relu2 = shared_relu

                  When traversing the model, a replacement will be generated when 'self.relu1' is encountered.
                  Let's call it `new_relu1'. When 'self.relu2' will be encountered, it'll simply be replaced
                  with a reference to 'new_relu1'. Any override configuration made specifically for 'self.relu2'
                  will be ignored. A warning message will be shown.
              """
              msglogger.info('Preparing model for quantization using {0}'.format(self.__class__.__name__))

              self.model.quantizer_metadata["dummy_input"] = dummy_input
              if dummy_input is not None:
                  summary_graph = distiller.SummaryGraph(self.model, dummy_input)
                  # Get the adjacency map
                  self.adjacency_map = summary_graph.adjacency_map(dedicated_modules_only=False)

              self._pre_prepare_model(dummy_input)  # Pre-processing: BN folding, optimizing modules with activations, etc.

              self._pre_process_container(self.model)  # Main work: replace the modules to be quantized with wrappers
              for module_name, module in self.model.named_modules():
                  qbits = self.module_qbits_map[module_name]
                  curr_parameters = dict(module.named_parameters())
                  for param_name, param in curr_parameters.items():
                      n_bits = qbits.bias if param_name.endswith('bias') else qbits.wts
                      if n_bits is None:
                          continue
                      fp_attr_name = param_name
                      if self.train_with_fp_copy:
                          hack_float_backup_parameter(module, param_name, n_bits)  # back up the float parameter
                          fp_attr_name = FP_BKP_PREFIX + param_name
                      # Record info about the parameter to quantize: owning module, whether there is an FP copy, bit-width
                      self.params_to_quantize.append(_ParamToQuant(module, module_name, fp_attr_name, param_name, n_bits))

                      param_full_name = '.'.join([module_name, param_name])
                      msglogger.info(
                          "Parameter '{0}' will be quantized to {1} bits".format(param_full_name, n_bits))

              # If an optimizer was passed, assume we need to update it
              # (e.g. because new parameters were added)
              if self.optimizer:
                  optimizer_type = type(self.optimizer)
                  new_optimizer = optimizer_type(self._get_updated_optimizer_params_groups(), **self.optimizer.defaults)
                  self.optimizer.__setstate__({'param_groups': new_optimizer.param_groups})

              self._post_prepare_model()  # Post-processing

              msglogger.info('Quantized model:\n\n{0}\n'.format(self.model))

          def _pre_prepare_model(self, dummy_input):
              pass

          def _pre_process_container(self, container, prefix=''):
              def replace_msg(module_name, modules=None):
                  msglogger.debug('Module ' + module_name)
                  if modules:
                      msglogger.debug('\tReplacing: {}.{}'.format(modules[0].__module__, modules[0].__class__.__name__))
                      msglogger.debug('\tWith:      {}.{}'.format(modules[1].__module__, modules[1].__class__.__name__))
                  else:
                      msglogger.debug('\tSkipping')

              # Iterate through model, insert quantization functions as appropriate
              # (walk every child module and replace it with its quantized counterpart)
              for name, module in container.named_children():
                  full_name = prefix + name
                  if module in self.modules_processed:
                      previous_name, previous_wrapper = self.modules_processed[module]
                      warnings.warn("Module '{0}' references to same module as '{1}'."
                                    ' Replacing with reference the same wrapper.'.format(full_name, previous_name),
                                    UserWarning)
                      if previous_wrapper:
                          replace_msg(full_name, (module, previous_wrapper))
                          setattr(container, name, previous_wrapper)
                      else:
                          replace_msg(full_name)
                      continue
                  current_qbits = self.module_qbits_map[full_name]
                  if current_qbits.acts is None and current_qbits.wts is None:
                      if self.module_overrides_map[full_name]:
                          raise ValueError("Adding overrides while not quantizing is not allowed.")
                      # We indicate this module wasn't replaced by a wrapper (no replacement)
                      replace_msg(full_name)
                      self.modules_processed[module] = full_name, None
                  else:
                      # We use a type hint comment to let IDEs know replace_fn is a function
                      # Look up the wrapper factory for this module type (replace_fn, covered later)
                      replace_fn = self.replacement_factory[type(module)]  # type: Optional[Callable]
                      # If the replacement function wasn't specified - continue without replacing this module.
                      if replace_fn is not None:
                          valid_kwargs, invalid_kwargs = distiller.filter_kwargs(self.module_overrides_map[full_name],
                                                                                 replace_fn)
                          if invalid_kwargs:
                              raise TypeError("""Quantizer of type %s doesn't accept \"%s\" as override arguments for %s. Allowed kwargs: %s"""
                                              % (type(self), list(invalid_kwargs), type(module), list(valid_kwargs)))
                          # Replace the module to be quantized with its wrapper
                          new_module = replace_fn(module, full_name, self.module_qbits_map, **valid_kwargs)
                          replace_msg(full_name, (module, new_module))
                          # Add to history of prepared submodules
                          self.modules_processed[module] = full_name, new_module
                          setattr(container, name, new_module)

                          # If a "leaf" module was replaced by a container, add the new layers to the QBits mapping
                          if not distiller.has_children(module) and distiller.has_children(new_module):
                              for sub_module_name, sub_module in new_module.named_modules():
                                  self._add_qbits_entry(full_name + '.' + sub_module_name, type(sub_module), current_qbits)
                              self.module_qbits_map[full_name] = QBits(acts=current_qbits.acts, wts=None, bias=None)

                  if distiller.has_children(module):
                      # For container we call recursively
                      self._pre_process_container(module, full_name + '.')

          def _get_updated_optimizer_params_groups(self):
              """
              Returns a list of model parameter groups and optimizer hyper-parameter overrides,
              as expected by the __init__ function of torch.optim.Optimizer.
              This is called after all model changes were made in prepare_model, in case an Optimizer instance was
              passed to __init__.
              Subclasses which add parameters to the model should override as needed.
              :return: List of parameter groups
              """
              # Default implementation - just return all model parameters as one group
              return [{'params': self.model.parameters()}]

          def _post_prepare_model(self):
              pass

          def quantize_params(self):
              """
              Quantize all parameters using self.param_quantization_fn (with the defined number of bits for each parameter)
              """
              for ptq in self.params_to_quantize:
                  q_param = self.param_quantization_fn(getattr(ptq.module, ptq.fp_attr_name), ptq)
                  if self.train_with_fp_copy:
                      setattr(ptq.module, ptq.q_attr_name, q_param)
                  else:
                      getattr(ptq.module, ptq.q_attr_name).data = q_param.data
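
The module-replacement idea and the replacement_factory lookup in _pre_process_container can be illustrated without PyTorch. In this sketch `Module`, `EltwiseScale`, `QuantWrapper`, `fake_quantize` and `replace_modules` are all hypothetical toy names: `Module` stands in for torch.nn.Module, the wrapper quantizes the input and weights with a per-tensor symmetric scheme and returns dequantized floats ("fake quantization"), and shared modules reuse the wrapper created at their first occurrence, as the docstring of prepare_model describes.

```python
from collections import OrderedDict, defaultdict


def fake_quantize(values, num_bits=8):
    """Symmetric linear quantize, then dequantize back to float (per-tensor, toy version)."""
    q_max = 2 ** (num_bits - 1) - 1
    sat = max((abs(v) for v in values), default=0.0)
    scale = q_max / sat if sat > 0 else 1.0
    return [round(v * scale) / scale for v in values]


class Module:
    """Hypothetical stand-in for torch.nn.Module: a container of named children run in order."""

    def __init__(self, **children):
        self.children = OrderedDict(children)

    def forward(self, x):
        for child in self.children.values():
            x = child.forward(x)
        return x


class EltwiseScale(Module):
    """Toy layer with weights: multiplies the input element-wise by its weight vector."""

    def __init__(self, weight):
        super().__init__()
        self.weight = weight

    def forward(self, x, weight=None):
        w = self.weight if weight is None else weight
        return [wi * xi for wi, xi in zip(w, x)]


class QuantWrapper(Module):
    """Toy wrapper: quantize the input and weights, let the real module compute, return floats."""

    def __init__(self, wrapped, num_bits=8):
        super().__init__()
        self.wrapped = wrapped
        self.num_bits = num_bits

    def forward(self, x):
        q_x = fake_quantize(x, self.num_bits)
        q_w = fake_quantize(self.wrapped.weight, self.num_bits)
        return self.wrapped.forward(q_x, weight=q_w)


def replace_modules(container, factory, processed=None):
    """Recursively swap children whose type has a factory entry (mirrors _pre_process_container)."""
    if processed is None:
        processed = {}
    for name, module in list(container.children.items()):
        if module in processed:
            # Shared module: reuse the wrapper created at its first occurrence
            container.children[name] = processed[module]
            continue
        replace_fn = factory[type(module)]  # None for types without a wrapper
        new_module = replace_fn(module) if replace_fn is not None else module
        processed[module] = new_module
        container.children[name] = new_module
        if module.children:
            replace_modules(module, factory, processed)
    return container


factory = defaultdict(lambda: None)
factory[EltwiseScale] = lambda m: QuantWrapper(m, num_bits=8)

shared = EltwiseScale([0.5, -0.25])
model = Module(block=Module(s1=shared, s2=shared))
replace_modules(model, factory)
```

The container `block` has no factory entry, so it is kept and recursed into; `s1` becomes a `QuantWrapper`, and `s2`, which references the same object, receives that same wrapper instead of a new one. Running `model.forward([1.0, 2.0])` afterwards stays close to the float result `[0.25, 0.125]`.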
      
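How __init__ resolves each module's bit-widths from the defaults and the regex overrides can be isolated into one function. `build_qbits_map` is a hypothetical name; the logic follows the constructor above: fold the bits_* keys into a QBits, join all patterns into one alternation of `(^pattern$)` groups, strip a DataParallel `module.` prefix, and pick the override of the first alternative that matched.

```python
import copy
import re
from collections import OrderedDict, namedtuple

QBits = namedtuple('QBits', ['acts', 'wts', 'bias'])


def build_qbits_map(module_names, default_qbits, overrides):
    """Resolve each module's QBits from the defaults plus regex overrides (first match wins)."""
    overrides = copy.deepcopy(overrides)  # leave the caller's dict untouched
    # Fold the bits_* keys of each override entry into one QBits stored under 'bits'
    for entry in overrides.values():
        entry['bits'] = QBits(acts=entry.pop('bits_activations', default_qbits.acts),
                              wts=entry.pop('bits_weights', default_qbits.wts),
                              bias=entry.pop('bits_bias', default_qbits.bias))
    patterns = list(overrides.keys())
    regex = re.compile('|'.join('(^{0}$)'.format(p) for p in patterns)) if patterns else None

    qbits_map = OrderedDict()
    for full_name in module_names:
        # DataParallel prefixes names with 'module.'; strip it before matching
        name_to_match = full_name.replace('module.', '', 1)
        qbits = default_qbits
        m = regex.match(name_to_match) if regex else None
        if m:
            # Index of the first alternative (= first pattern, in order) that matched
            group_idx = next(i for i, g in enumerate(m.groups()) if g is not None)
            entry = overrides.get(name_to_match) or overrides[patterns[group_idx]]
            qbits = entry['bits']
        qbits_map[full_name] = qbits
    return qbits_map


defaults = QBits(acts=8, wts=8, bias=32)
user_overrides = OrderedDict([
    ('conv[12]', {'bits_weights': 4}),                       # specific pattern first
    ('conv.*', {'bits_weights': 6, 'bits_activations': 6}),  # broad pattern second
])
qbits_map = build_qbits_map(['conv1', 'conv3', 'fc', 'module.conv2'], defaults, user_overrides)
```

Because 'conv[12]' comes before 'conv.*', conv1 and conv2 get 4-bit weights while conv3 falls through to the broad rule; swapping the two entries would give every conv layer 6 bits. An override key that equals a module name exactly is looked up first, as in the source. Like the group-index lookup in the source, this assumes the patterns contain no capturing groups of their own.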

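The train_with_fp_copy flow listed in the docstring (steps 1 to 4), which quantize_params supports, can be shown with a single weight. This is a toy sketch, not Distiller's code: `quantize` and `train` are hypothetical helpers, and it assumes the gradient of the quantize op is passed straight through unchanged (the straight-through estimator, STE), a common choice for step 3.2.

```python
def quantize(w, scale, q_max):
    """Symmetric linear quantize-then-dequantize of one weight, clamped to [-q_max, q_max]."""
    q = max(-q_max, min(q_max, round(w * scale)))
    return q / scale


def train(target, steps, lr, num_bits=4, sat=1.0, fp_copy=True):
    """Fit one weight w to `target` by gradient descent on (q_w - target)^2.

    fp_copy=True keeps a float shadow weight, like train_with_fp_copy=True;
    fp_copy=False stores only the quantized weight, re-quantizing after each update.
    """
    q_max = 2 ** (num_bits - 1) - 1
    scale = q_max / sat
    fp_w = 0.0
    for _ in range(steps):
        q_w = quantize(fp_w, scale, q_max)  # 1. q_weights = quantize(fp_weights)
        grad = 2 * (q_w - target)           # 2.-3.1 forward with q_w, gradient w.r.t. q_w
        fp_w -= lr * grad                   # 3.2 + 4. STE: pass grad through quantize, update fp copy
        if not fp_copy:
            fp_w = quantize(fp_w, scale, q_max)  # no float copy: small updates are rounded away
    return quantize(fp_w, scale, q_max)


with_copy = train(0.45, steps=200, lr=0.01, fp_copy=True)
without_copy = train(0.45, steps=200, lr=0.01, fp_copy=False)
```

With 4 bits the step size is 1/7, so a single update of about 0.009 is well below half a step. Without the float copy the weight is rounded back to 0 every time and never moves; with the copy the updates accumulate and the quantized weight reaches 3/7, the grid point closest to 0.45.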
Summary

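        This article introduced Distiller and part of its quantization implementation, mainly a brief walk-through of the Quantizer base class; all concrete quantizers inherit from this base class.

        Because the code is long, the implementation of the concrete quantizers is covered in a follow-up article (Intel Distiller Toolkit - Quantization Implementation 2).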

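This article was submitted by an internet user; its views are the author's own and do not represent this site's position. This site only provides information storage services, does not own the rights, and bears no related legal liability. When reprinting, please cite the source: http://xiahunao.cn/news/1381828.html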
