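Articles in this series
Intel Distiller Toolkit - Quantization Implementation 1
Intel Distiller Toolkit - Quantization Implementation 2
Distiller
- Distiller is a neural-network compression toolkit that Intel developed around 2019. It supports pruning, quantization, knowledge distillation, low-rank decomposition, and more.
- This article explains how Distiller implements quantization. Distiller has received almost no updates since 2019, so the focus here is on its classic quantization schemes, as study material.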
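Distiller Quantization Implementation
- First, I use the GNMT quantization code from the Distiller examples to introduce Distiller's quantization framework. The code is shown in the figure below.
- The code consists of three steps:
  - Collect statistics (the calibrator, QuantCalibrationStatsCollector)
  - Create the quantizer (here, the post-training quantizer PostTrainLinearQuantizer)
  - Apply the quantizer to the model (prepare_model)
- The rest of this article focuses on creating and applying the quantizer (steps 2 and 3 above).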
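Quantizer
- The core idea behind Distiller's quantizers is module replacement: every module to be quantized, such as conv, linear, embedding, or a parameterless op (add, multiply, concat, etc.), is wrapped in a wrapper. At inference time the wrapper quantizes the inputs and the weights (parameterless modules have none), hands them to the wrapped module (conv, linear, parameterless op, etc.) for the actual computation, and finally dequantizes the result where needed.
- Note: Distiller wraps addition (elementwise_add), element-wise multiplication (elementwise_mul), matrix multiplication (matMul), batched matrix multiplication (BatchMatMul) and concat in nn.Module subclasses, so that all of them can be quantized uniformly with the same module-replacement approach.
- Now let's look at the definition of the Quantizer base class. The code below is annotated, but readers who have not gone through the full Distiller source may still find parts of it hard to follow, so interested readers are encouraged to read the source itself. The class consists mainly of the following parts:
  - Key variables
    - Quantization bit-widths: a default setting for every module plus external overrides (overrides), represented as QBits(acts, wts, bias)
    - Replacement factory: replacement_factory, a dict that maps each module type to the function that builds its wrapper; it is populated by each concrete Quantizer subclass, as covered in the next article (unregistered types map to None and are left unreplaced)
    - Parameters to quantize: params_to_quantize, a list recording every parameter to be quantized together with its context (owning module, bit-width, etc.)
    - Parameter quantization function: param_quantization_fn, the function that quantizes parameters; set by the Quantizer subclass
  - Processing flow (prepare_model)
    - Pre-processing: e.g. BN folding (BN behaves differently in training and inference) and activation optimization; done in _pre_prepare_model, which is a no-op in the base class
    - Quantization replacement: replace each module to be quantized with its wrapper
    - Post-processing: done in _post_prepare_model, also a no-op in the base class
- The source code is as follows: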
```python
import copy
import logging
import re
import warnings
from collections import OrderedDict, defaultdict
from typing import Callable, Optional

import torch.nn as nn

import distiller

# QBits, _ParamToQuant, FP_BKP_PREFIX and hack_float_backup_parameter are defined
# elsewhere in the same Distiller source file (not shown here).
msglogger = logging.getLogger()


class Quantizer(object):
    r"""Base class for quantizers.

    Args:
        model (torch.nn.Module): The model to be quantized
        optimizer (torch.optim.Optimizer): An optimizer instance, required in cases where the quantizer is going
            to perform changes to existing model parameters and/or add new ones.
            Specifically, when train_with_fp_copy is True, this cannot be None.
        bits_activations/weights/bias (int): Default number of bits to use when quantizing each tensor type.
            Value of None means do not quantize.
        overrides (OrderedDict): Dictionary mapping regular expressions of layer name patterns to dictionary with
            overrides of default values.
            The keys in the overrides dictionary should be parameter names that the Quantizer accepts default values
            for in its init function.
            The parameters 'bits_activations', 'bits_weights', and 'bits_bias' which are accepted by the base Quantizer
            are supported by default.
            Other than those, each sub-class of Quantizer defines the set of parameter for which it supports
            over-riding.
            OrderedDict is used to enable handling of overlapping name patterns. So, for example, one could define
            certain override parameters for a group of layers, e.g. 'conv*', but also define different parameters for
            specific layers in that group, e.g. 'conv1'.
            The patterns are evaluated eagerly - the first match wins. Therefore, the more specific patterns must
            come before the broad patterns.
        train_with_fp_copy (bool): If true, will modify layers with weights to keep both a quantized and
            floating-point copy, such that the following flow occurs in each training iteration:
            1. q_weights = quantize(fp_weights)
            2. Forward through network using q_weights
            3. In back-prop:
                3.1 Gradients calculated with respect to q_weights
                3.2 We also back-prop through the 'quantize' operation from step 1
            4. Update fp_weights with gradients calculated in step 3.2
    """
    def __init__(self, model, optimizer=None,
                 bits_activations=None, bits_weights=None, bits_bias=None,
                 overrides=None, train_with_fp_copy=False):
        if overrides is None:
            overrides = OrderedDict()
        if not isinstance(overrides, OrderedDict):
            raise TypeError('overrides must be an instance of collections.OrderedDict or None')

        if train_with_fp_copy and optimizer is None:
            raise ValueError('optimizer cannot be None when train_with_fp_copy is True')

        # Node relationships of the computation graph, used later for activation optimization
        self.adjacency_map = None  # To be populated during prepare_model()

        # Default quantization bit-widths
        self.default_qbits = QBits(acts=bits_activations, wts=bits_weights, bias=bits_bias)
        self.overrides = overrides

        self.model = model
        self.optimizer = optimizer

        # Stash some quantizer data in the model so we can re-apply the quantizer on a resuming model
        self.model.quantizer_metadata = {'type': type(self),
                                         'params': {'bits_activations': bits_activations,
                                                    'bits_weights': bits_weights,
                                                    'bits_bias': bits_bias,
                                                    'overrides': copy.deepcopy(overrides)}}

        for k, v in self.overrides.items():
            if any(old_bits_key in v.keys() for old_bits_key in ['acts', 'wts', 'bias']):
                raise ValueError("Using 'acts' / 'wts' / 'bias' to specify bit-width overrides is deprecated.\n"
                                 "Please use the full parameter names: "
                                 "'bits_activations' / 'bits_weights' / 'bits_bias'")
            qbits = QBits(acts=v.pop('bits_activations', self.default_qbits.acts),
                          wts=v.pop('bits_weights', self.default_qbits.wts),
                          bias=v.pop('bits_bias', self.default_qbits.bias))
            v['bits'] = qbits

        # Prepare explicit mapping from each layer to QBits based on default + overrides
        patterns = []
        regex_overrides = None
        # Some modules override the default quantization settings
        if overrides:
            patterns = list(overrides.keys())
            regex_overrides_str = '|'.join(['(^{0}$)'.format(pattern) for pattern in patterns])
            regex_overrides = re.compile(regex_overrides_str)

        self.module_qbits_map = {}
        self.module_overrides_map = {}
        # Set the quantization bit-widths of every module
        for module_full_name, module in model.named_modules():
            # Need to account for scenario where model is parallelized with DataParallel, which wraps the original
            # module with a wrapper module called 'module' :)
            name_to_match = module_full_name.replace('module.', '', 1)
            qbits = self.default_qbits
            override_entry = self.overrides.get(name_to_match, OrderedDict())
            if regex_overrides:
                m_overrides = regex_overrides.match(name_to_match)
                if m_overrides:
                    group_idx = 0
                    groups = m_overrides.groups()
                    while groups[group_idx] is None:
                        group_idx += 1
                    override_entry = copy.deepcopy(override_entry or self.overrides[patterns[group_idx]])
                    qbits = override_entry.pop('bits', self.default_qbits)

            self._add_qbits_entry(module_full_name, type(module), qbits)
            self._add_override_entry(module_full_name, override_entry)

        # Mapping from module type to function generating a replacement module suited for quantization
        # To be populated by child classes
        # Unspecified layer types return None by default.
        self.replacement_factory = defaultdict(lambda: None)
        # Pointer to parameters quantization function, triggered during training process
        # To be populated by child classes
        self.param_quantization_fn = None  # parameter quantization function

        self.train_with_fp_copy = train_with_fp_copy
        self.params_to_quantize = []

        # A dictionary of replaced modules and their respective names.
        self.modules_processed = OrderedDict()  # modules that have already been processed

    def _add_qbits_entry(self, module_name, module_type, qbits):
        if module_type not in [nn.Conv2d, nn.Conv3d, nn.Linear, nn.Embedding]:
            # For now we support weights quantization only for Conv, FC and Embedding layers (so, for example, we don't
            # support quantization of batch norm scale parameters)
            qbits = QBits(acts=qbits.acts, wts=None, bias=None)
        self.module_qbits_map[module_name] = qbits

    def _add_override_entry(self, module_name, entry):
        self.module_overrides_map[module_name] = entry

    def prepare_model(self, dummy_input=None):
        """Traverses the model and replaces sub-modules with quantized counterparts according to the bit-width
        and overrides configuration provided to __init__(), and according to the replacement_factory as
        defined by the Quantizer sub-class being used.

        Note:
            If multiple sub-modules within the model actually reference the same module, then that module
            is replaced only once, according to the configuration (bit-width and/or overrides) of the
            first encountered reference.
            Toy Example - say a module is constructed using this bit of code:

                shared_relu = nn.ReLU
                self.relu1 = shared_relu
                self.relu2 = shared_relu

            When traversing the model, a replacement will be generated when 'self.relu1' is encountered.
            Let's call it `new_relu1'. When 'self.relu2' will be encountered, it'll simply be replaced
            with a reference to 'new_relu1'. Any override configuration made specifically for 'self.relu2'
            will be ignored. A warning message will be shown.
        """
        msglogger.info('Preparing model for quantization using {0}'.format(self.__class__.__name__))

        self.model.quantizer_metadata["dummy_input"] = dummy_input
        if dummy_input is not None:
            summary_graph = distiller.SummaryGraph(self.model, dummy_input)
            # Get the adjacency map
            self.adjacency_map = summary_graph.adjacency_map(dedicated_modules_only=False)

        self._pre_prepare_model(dummy_input)  # pre-processing: BN folding, optimizing modules with activations, etc.

        self._pre_process_container(self.model)  # main work: replace modules with their quantization wrappers
        for module_name, module in self.model.named_modules():
            qbits = self.module_qbits_map[module_name]
            curr_parameters = dict(module.named_parameters())
            for param_name, param in curr_parameters.items():
                n_bits = qbits.bias if param_name.endswith('bias') else qbits.wts
                if n_bits is None:
                    continue
                fp_attr_name = param_name
                if self.train_with_fp_copy:
                    hack_float_backup_parameter(module, param_name, n_bits)  # back up the float parameter
                    fp_attr_name = FP_BKP_PREFIX + param_name
                # Record each parameter to quantize: owning module, whether it has an FP copy, bit-width
                self.params_to_quantize.append(_ParamToQuant(module, module_name, fp_attr_name, param_name, n_bits))

                param_full_name = '.'.join([module_name, param_name])
                msglogger.info(
                    "Parameter '{0}' will be quantized to {1} bits".format(param_full_name, n_bits))

        # If an optimizer was passed, assume we need to update it
        # (e.g. because new parameters were added to the model)
        if self.optimizer:
            optimizer_type = type(self.optimizer)
            new_optimizer = optimizer_type(self._get_updated_optimizer_params_groups(), **self.optimizer.defaults)
            self.optimizer.__setstate__({'param_groups': new_optimizer.param_groups})

        self._post_prepare_model()  # post-processing

        msglogger.info('Quantized model:\n\n{0}\n'.format(self.model))

    def _pre_prepare_model(self, dummy_input):
        pass

    def _pre_process_container(self, container, prefix=''):
        def replace_msg(module_name, modules=None):
            msglogger.debug('Module ' + module_name)
            if modules:
                msglogger.debug('\tReplacing: {}.{}'.format(modules[0].__module__, modules[0].__class__.__name__))
                msglogger.debug('\tWith:      {}.{}'.format(modules[1].__module__, modules[1].__class__.__name__))
            else:
                msglogger.debug('\tSkipping')

        # Iterate through model, insert quantization functions as appropriate
        # (walk every child module and replace it with its quantization wrapper)
        for name, module in container.named_children():
            full_name = prefix + name
            if module in self.modules_processed:
                previous_name, previous_wrapper = self.modules_processed[module]
                warnings.warn("Module '{0}' references to same module as '{1}'."
                              ' Replacing with reference the same wrapper.'.format(full_name, previous_name),
                              UserWarning)
                if previous_wrapper:
                    replace_msg(full_name, (module, previous_wrapper))
                    setattr(container, name, previous_wrapper)
                else:
                    replace_msg(full_name)
                continue
            current_qbits = self.module_qbits_map[full_name]
            if current_qbits.acts is None and current_qbits.wts is None:
                if self.module_overrides_map[full_name]:
                    raise ValueError("Adding overrides while not quantizing is not allowed.")
                # We indicate this module wasn't replaced by a wrapper (no replacement)
                replace_msg(full_name)
                self.modules_processed[module] = full_name, None
            else:
                # We use a type hint comment to let IDEs know replace_fn is a function
                # Look up the wrapper factory for this module type (replace_fn, covered in the next article)
                replace_fn = self.replacement_factory[type(module)]  # type: Optional[Callable]
                # If the replacement function wasn't specified - continue without replacing this module.
                if replace_fn is not None:
                    valid_kwargs, invalid_kwargs = distiller.filter_kwargs(self.module_overrides_map[full_name],
                                                                           replace_fn)
                    if invalid_kwargs:
                        raise TypeError("""Quantizer of type %s doesn't accept \"%s\" as override arguments for %s. Allowed kwargs: %s"""
                                        % (type(self), list(invalid_kwargs), type(module), list(valid_kwargs)))
                    # Replace the module to be quantized with its wrapper
                    new_module = replace_fn(module, full_name, self.module_qbits_map, **valid_kwargs)
                    replace_msg(full_name, (module, new_module))
                    # Add to history of prepared submodules
                    self.modules_processed[module] = full_name, new_module
                    setattr(container, name, new_module)

                    # If a "leaf" module was replaced by a container, add the new layers to the QBits mapping
                    if not distiller.has_children(module) and distiller.has_children(new_module):
                        for sub_module_name, sub_module in new_module.named_modules():
                            self._add_qbits_entry(full_name + '.' + sub_module_name, type(sub_module), current_qbits)
                        self.module_qbits_map[full_name] = QBits(acts=current_qbits.acts, wts=None, bias=None)

            if distiller.has_children(module):
                # For container we call recursively
                self._pre_process_container(module, full_name + '.')

    def _get_updated_optimizer_params_groups(self):
        """
        Returns a list of model parameter groups and optimizer hyper-parameter overrides,
        as expected by the __init__ function of torch.optim.Optimizer.
        This is called after all model changes were made in prepare_model, in case an Optimizer instance was
        passed to __init__.

        Subclasses which add parameters to the model should override as needed.

        :return: List of parameter groups
        """
        # Default implementation - just return all model parameters as one group
        return [{'params': self.model.parameters()}]

    def _post_prepare_model(self):
        pass

    def quantize_params(self):
        """
        Quantize all parameters using self.param_quantization_fn (with the defined number of bits for each parameter)
        """
        for ptq in self.params_to_quantize:
            q_param = self.param_quantization_fn(getattr(ptq.module, ptq.fp_attr_name), ptq)
            if self.train_with_fp_copy:
                setattr(ptq.module, ptq.q_attr_name, q_param)
            else:
                getattr(ptq.module, ptq.q_attr_name).data = q_param.data
```
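The module-replacement idea, together with the shared-module rule described in the prepare_model() docstring, can be illustrated with a stdlib-only toy. `Module`, `Linear`, `ReLU`, `QuantWrapper` and `replace_modules` are hypothetical stand-ins, not Distiller classes; the factory functions here take `(module, name)` rather than the `(module, name, qbits_map, **kwargs)` call seen in `_pre_process_container`, and the scales assume inputs and weights lie in [-1, 1].

```python
from collections import OrderedDict, defaultdict


class Module:
    """Toy stand-in for torch.nn.Module: keeps named child modules in insertion order."""

    def __init__(self):
        self._children = OrderedDict()

    def add(self, name, child):
        self._children[name] = child
        return child

    def named_children(self):
        return list(self._children.items())


class Linear(Module):
    """Toy bias-free scalar layer: y = w * x."""

    def __init__(self, w):
        super().__init__()
        self.w = w

    def __call__(self, x):
        return self.w * x


class ReLU(Module):
    """Toy activation with no registered replacement, so it is left as-is."""

    def __call__(self, x):
        return max(0.0, x)


class QuantWrapper(Module):
    """Hypothetical wrapper: quantize input and weight, let the wrapped module compute, dequantize."""

    def __init__(self, wrapped, scale_x, scale_w):
        super().__init__()
        self.wrapped, self.scale_x, self.scale_w = wrapped, scale_x, scale_w
        # Post-training style: quantize the wrapped module's weight once, in place.
        wrapped.w = round(wrapped.w * scale_w)

    def __call__(self, x):
        q_x = round(x * self.scale_x)               # quantize the input
        acc = self.wrapped(q_x)                     # the real module computes on integer-valued data
        return acc / (self.scale_x * self.scale_w)  # dequantize the result


def replace_modules(container, factory, processed, prefix=""):
    """Toy version of _pre_process_container: recursive module replacement."""
    for name, child in container.named_children():
        full_name = prefix + name
        if child in processed:
            # Shared module: reuse the wrapper created for its first reference (if any).
            if processed[child] is not None:
                container._children[name] = processed[child]
            continue
        replace_fn = factory[type(child)]  # defaultdict: None for unregistered types
        new_child = replace_fn(child, full_name) if replace_fn is not None else None
        processed[child] = new_child
        if new_child is not None:
            container._children[name] = new_child
        if child.named_children():
            # Containers are processed recursively, as in the original.
            replace_modules(child, factory, processed, full_name + ".")


SCALE = 127.0  # 8-bit symmetric scale, assuming values in [-1, 1] (toy assumption)
factory = defaultdict(lambda: None)
factory[Linear] = lambda m, name: QuantWrapper(m, scale_x=SCALE, scale_w=SCALE)

model = Module()
shared = Linear(0.5)
model.add("fc1", shared)
model.add("relu", ReLU())
block = model.add("block", Module())
block.add("fc2", Linear(-0.25))
model.add("fc1_again", shared)  # the same object referenced twice

replace_modules(model, factory, {})
```

Because `fc1` and `fc1_again` point to the same object, the second reference is simply redirected to the first wrapper and the weight is quantized only once, which mirrors the `modules_processed` check in the original. `relu` has no factory entry, so it stays unchanged, while `fc2` inside the nested container is reached through recursion.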
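How `__init__` turns `overrides` into a per-module QBits entry, with the first matching pattern winning, can be reproduced in isolation. `build_qbits_map` and `WEIGHT_QUANT_TYPES` are hypothetical names; module types are given as strings instead of nn.Module classes, and the exact-name lookup the original performs before the regex match is left out. Since the patterns are regular expressions, a "conv*" style group is written here as `conv.*` (in regex syntax `conv*` would mean "con" followed by any number of "v").

```python
import re
from collections import OrderedDict, namedtuple

# Stand-in for Distiller's QBits triple of bit-widths (None = do not quantize).
QBits = namedtuple("QBits", ["acts", "wts", "bias"])

# Hypothetical: type names whose weights may be quantized (mirrors _add_qbits_entry).
WEIGHT_QUANT_TYPES = {"Conv2d", "Conv3d", "Linear", "Embedding"}


def build_qbits_map(module_types, default, overrides):
    """Resolve each module's QBits from the defaults plus the first matching override pattern.

    module_types: dict of module name -> type name (stand-in for model.named_modules()).
    overrides: OrderedDict of regex pattern -> dict with optional bits_* keys.
    """
    patterns = list(overrides.keys())
    regex = None
    if patterns:
        # One capture group per pattern, anchored at both ends, as in Quantizer.__init__.
        regex = re.compile("|".join("(^{0}$)".format(p) for p in patterns))
    qbits_map = {}
    for name, type_name in module_types.items():
        qbits = default
        m = regex.match(name) if regex else None
        if m:
            # Alternatives are tried left to right, so the first non-None group is the first pattern that matched.
            group_idx = next(i for i, g in enumerate(m.groups()) if g is not None)
            entry = overrides[patterns[group_idx]]
            qbits = QBits(acts=entry.get("bits_activations", default.acts),
                          wts=entry.get("bits_weights", default.wts),
                          bias=entry.get("bits_bias", default.bias))
        if type_name not in WEIGHT_QUANT_TYPES:
            # Other module types keep only activation quantization.
            qbits = QBits(acts=qbits.acts, wts=None, bias=None)
        qbits_map[name] = qbits
    return qbits_map


module_types = {"conv1": "Conv2d", "conv2": "Conv2d", "relu": "ReLU", "fc": "Linear"}
default = QBits(acts=8, wts=8, bias=32)
overrides = OrderedDict([("conv1", {"bits_weights": 4}),        # specific pattern first
                         ("conv.*", {"bits_activations": 6})])  # broad pattern second
qmap = build_qbits_map(module_types, default, overrides)
```

With this order `conv1` gets 4-bit weights while `conv2` gets 6-bit activations. Swapping the two patterns would give `conv1` the broad `conv.*` settings instead, which is why the docstring asks for specific patterns to come before broad ones.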
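Summary
This article introduced Distiller and part of its quantization implementation, mainly the Quantizer base class. All concrete quantizers inherit from this base class.
Because the code is long, the concrete quantizer implementations are covered in the next article (Intel Distiller Toolkit - Quantization Implementation 2).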