QConfig 详解

定义

QConfig 的定义

模型的量化方式由 qconfig 决定,在准备 qat / calibration 模型之前,需要先给模型设置 qconfig。

注意

因历史原因,Plugin 中有不同 qconfig 的定义和用法,早期版本的 qconfig 将在不久的将来被废弃,我们只推荐您使用此文档中介绍的 qconfig 用法。

一个 qconfig 对象可以设置 input / weight / output 三个关键字,分别表示算子输入/权重/输出的量化配置,prepare 模型时会根据这些配置决定是否要在对应位置插入 FakeQuantize / FakeCast 节点,None 表示不插入任何节点。

import torch from horizon_plugin_pytorch.quantization.qconfig import QConfig from horizon_plugin_pytorch.quantization.fake_quantize import FakeQuantize from horizon_plugin_pytorch.quantization.fake_cast import FakeCast from horizon_plugin_pytorch.quantization.observer_v2 import MinMaxObserver from horizon_plugin_pytorch.dtype import qint8 qconfig = QConfig( input=None, weight=FakeQuantize.with_args( observer=MinMaxObserver, dtype=qint8, qscheme=torch.per_channel_symmetric, ch_axis=0, ), output=FakeCast.with_args(dtype=torch.float16), # activation=xxx 早期用法,作用与 output 关键字一致,当前仍兼容,但建议您使用 output 关键字。 )

FakeQuantize 的定义

FakeQuantize 是伪量化节点,会对输入进行量化反量化操作,插入伪量化可以在浮点模型的前向中模拟量化产生的误差。horizon_plugin_pytorch 支持 FakeQuantize / PACTFakeQuantize / _LearnableFakeQuantize 三种伪量化,我们只推荐您使用基于统计的 FakeQuantize,可以满足绝大部分需求。标准流程不对 PACTFakeQuantize 和 _LearnableFakeQuantize 两种方法做详细说明,如果一定有需求,请在阅读相关论文后再使用。

# 基于统计的方法 from horizon_plugin_pytorch.quantization.fake_quantize import FakeQuantize # https://arxiv.org/pdf/1805.06085 from horizon_plugin_pytorch.quantization.pact_fake_quantize import PACTFakeQuantize # https://arxiv.org/pdf/1902.08153 from horizon_plugin_pytorch.quantization._learnable_fake_quantize import _LearnableFakeQuantize

可以调用 FakeQuantize 的 with_args 方法得到构造器,并按上一节的代码示例用它构造 qconfig。with_args 的参数包括 FakeQuantize 和 observer 支持配置的参数,理论上可以配置所有 FakeQuantize 和 observer 类 init 方法声明中的参数,但为了屏蔽无关紧要的细节,我们只推荐您配置 observer 相关参数。

不同 observer 的参数不同,下面列出常用 observer 构造 FakeQuantize 的例子,其他 observer 的具体用法见校准章节。

import torch from horizon_plugin_pytorch.quantization.qconfig import QConfig from horizon_plugin_pytorch.quantization.fake_quantize import FakeQuantize from horizon_plugin_pytorch.quantization.observer_v2 import MinMaxObserver, FixedScaleObserver, MSEObserver from horizon_plugin_pytorch.dtype import qint8 # MinMaxObserver 的 __init__ 方法包含很多参数,with_args 方法可以控制这些参数。 # 我们只推荐您设置 fq_constructor_1 示例中的几个参数。 # def __init__( # self, # averaging_constant: float = 0.01, # ch_axis: int = -1, # dtype: Union[torch.dtype, QuantDType] = qint8, # qscheme: torch.qscheme = torch.per_tensor_symmetric, # quant_min: int = None, # quant_max: int = None, # is_sync_quantize: bool = False, # factory_kwargs: Dict = None, # ) -> None: fq_constructor_1 = FakeQuantize.with_args( observer=MinMaxObserver, # 适用于 qat 阶段的 input / output / weight 和 calibration 阶段的 weight。 averaging_constant=0.01, # calibration 后进行 qat 时,可将 input / output 的 averaging_constant 置为 0 以固定 scale。 dtype=qint8, # 量化类型,考虑算子的支持情况进行设置。 qscheme=torch.per_channel_symmetric, # 只有 weight 支持 per channel 量化。 ch_axis=0, # per channel 量化时指定 channel。 ) # 同理,您也可以查看 FixedScaleObserver 和 MSEObserver 的 __init__ 方法了解有哪些可以设置的参数。 fq_constructor_2 = FakeQuantize.with_args( observer=FixedScaleObserver, # 固定 scale,无论何种情况都不会变。 dtype=qint8, # 量化类型,考虑算子的支持情况进行设置。 scale=INPUT_ABS_MAX / 128, # 设定的 scale 值,一般设为绝对值最大值除以量化类型最大值。 ) fq_constructor_3 = FakeQuantize.with_args( observer=MSEObserver, # 适用于 calibration 阶段的 input / output。 dtype=qint8, # 量化类型,考虑算子的支持情况进行设置。 ) qconfig = QConfig( weight=fq_constructor_x, ... )

FakeCast 的定义

FakeCast 是伪转换节点,会将输入转换为 float32 类型,如果数据类型是 float16,那么还会在中间模拟转 float16 产生的截断误差,此节点主要用于标志需要浮点计算的算子。

使用 FakeCast 构造 qconfig 的方法与 FakeQuantize 类似,但只有 dtype 一个参数。

import torch from horizon_plugin_pytorch.quantization.qconfig import QConfig from horizon_plugin_pytorch.quantization.fake_cast import FakeCast qconfig = QConfig( input=FakeCast.with_args(dtype=torch.float16), # 考虑算子的支持情况进行设置。 ... )

构造 QConfig

  1. 按照上文介绍的方法,直接构造 QConfig 对象。这种方法比较灵活,可以配置任何可配置的参数,需要您对 QConfig 有一定的理解。

  2. 使用 get_qconfig 接口。此接口较直接构造 QConfig 对象的方法更简单易用,但不够灵活,高级用法和需求无法使用此接口实现。

import torch from horizon_plugin_pytorch.quantization import get_qconfig from horizon_plugin_pytorch.quantization.observer_v2 import MinMaxObserver from horizon_plugin_pytorch.quantization.qconfig import QConfig from horizon_plugin_pytorch.quantization.fake_quantize import FakeQuantize from horizon_plugin_pytorch.dtype import qint8 # qconfig_1 / qconfig_2 / qconfig_3 / qconfig_4 等价。 qconfig_1 = QConfig( weight=FakeQuantize.with_args( observer=MinMaxObserver, averaging_constant=0.01, dtype=qint8, qscheme=torch.per_channel_symmetric, ch_axis=0, ), output=FakeQuantize.with_args( observer=MinMaxObserver, averaging_constant=0, dtype=qint8, qscheme=torch.per_tensor_symmetric, ch_axis=-1, ), ) qconfig_2 = QConfig( weight=FakeQuantize.with_args( observer=MinMaxObserver, qscheme=torch.per_channel_symmetric, ch_axis=0, ), output=FakeQuantize.with_args( observer=MinMaxObserver, averaging_constant=0, ), ) qconfig_3 = get_qconfig( observer=MinMaxObserver, # 输入输出 observer 类型,只支持 horizon_plugin_pytorch.quantization.observer_v2 中的 MinMaxObserver 和 MSEObserver,默认值为 MinMaxObserver。 in_dtype=None, # 输入数据类型,考虑算子的支持情况进行设置。None 表示 QConfig 的 input 关键字为 None,默认值为 None。 weight_dtype=qint8, # 权重数据类型,考虑算子的支持情况进行设置。None 表示 QConfig 的 weight 关键字为 None,默认值为 qint8。 out_dtype=qint8, # 输出数据类型,考虑算子的支持情况进行设置。None 表示 QConfig 的 output 关键字为 None,默认值为 qint8。 fix_scale=True, # 是否固定输入输出 scale。 ) qconfig_4 = get_qconfig(fix_scale=True)

使用方法

设置 QConfig 属性

直接设置 qconfig 属性。此方法优先级最高,其余方法不会覆盖直接设置的 qconfig。

model.qconfig = QConfig(...)

QConfig 模板

qconfig 模板基于 subclass trace 方案感知模型的图结构,并按设定的规则自动设置 qconfig,是我们最推荐的设置 qconfig 方法。使用模板需要在 prepare 接口上指定 qconfig setter 和 example_inputs。用法如下:

from horizon_plugin_pytorch.quantization import prepare from horizon_plugin_pytorch.quantization.qconfig_template import ( default_qat_qconfig_setter, sensitive_op_qat_8bit_weight_16bit_act_qconfig_setter ) qat_model = prepare( model, example_inputs=example_inputs, # 用来感知图结构 qconfig_setter=( # qconfig 模板,支持传入多个模板,优先级从高到低。 sensitive_op_qat_8bit_weight_16bit_act_qconfig_setter(table, ratio=0.2), default_qat_qconfig_setter, ), )
注意

模板的优先级低于直接给模型设置 qconfig 属性,如果模型在 prepare 之前已经使用 model.qconfig = xxx 进行了配置,那么模板将不会生效。如果没有特殊需求,我们不推荐将两者混合使用,这很容易引发低级错误。绝大多数情况下,使用模板和 model.qconfig = xxx 两种设置方式中的一种即可满足需求。

模板可分为三类:

  1. 固定模板。固定模板中 calibration / qat / qat_fixed_act_scale 区别在于使用的 observer 类型和 scale 更新逻辑,分别用于校准,qat 训练,固定 activation scale qat 训练。default 模板 ( default_calibration_qconfig_setter / default_qat_qconfig_setter / default_qat_fixed_act_qconfig_setter ) 会做三件事:首先,将可以设置的高精度输出都设置上,对于不支持高精度的输出将给出提示;然后,从 grid sample 算子的 grid 输入向前搜索,直到出现第一个 gemm 类算子或者 QuantStub,将中间的所有算子都设置为 int16。根据经验这里的 grid 一般表达范围较宽,int8 有较大可能不满足精度需求;最后,将其余算子设置为 int8。int16 模板 ( qat_8bit_weight_16bit_act_qconfig_setter / qat_8bit_weight_16bit_fixed_act_qconfig_setter / calibration_8bit_weight_16bit_act_qconfig_setter ) 会做两件事:首先,将可以设置的高精度输出都设置上,对于不支持高精度的输出将给出提示;其次,将其余算子设置为 int16。
from horizon_plugin_pytorch.quantization.qconfig_template import ( default_calibration_qconfig_setter, default_qat_qconfig_setter, default_qat_fixed_act_qconfig_setter, qat_8bit_weight_16bit_act_qconfig_setter, qat_8bit_weight_16bit_fixed_act_qconfig_setter, calibration_8bit_weight_16bit_act_qconfig_setter, )
  1. 敏感度模板。敏感度模板有 sensitive_op_calibration_8bit_weight_16bit_act_qconfig_setter,sensitive_op_qat_8bit_weight_16bit_act_qconfig_setter,sensitive_op_qat_8bit_weight_16bit_fixed_act_qconfig_setter,三者的区别和固定模板中三者的区别一致,也是分别用于校准,qat 训练,固定 activation scale qat 训练。 敏感度模板的第一个输入是精度 debug 工具产生的敏感度结果,第二个参数可以指定 ratio 或 topk,敏感度模板会将量化敏感度最高的 topk 个算子设置为 int16。搭配固定模板,可以轻松实现混合精度调优。若模型有多个输出,每个输出都会产生一个敏感度表,您可以设置多个敏感度模版。
from horizon_plugin_pytorch.quantization.qconfig_template import ( default_calibration_qconfig_setter, sensitive_op_qat_8bit_weight_16bit_act_qconfig_setter, sensitive_op_qat_8bit_weight_16bit_fixed_act_qconfig_setter, sensitive_op_calibration_8bit_weight_16bit_act_qconfig_setter, ) table1 = torch.load("output_0-0_L1_sensitive_ops.pt") table2 = torch.load("output_0-1_L1_sensitive_ops.pt") calibration_model = prepare( model, example_inputs=example_input, qconfig_setter=( sensitive_op_calibration_8bit_weight_16bit_act_qconfig_setter(table1, ratio=0.2), sensitive_op_calibration_8bit_weight_16bit_act_qconfig_setter(table2, ratio=0.2), default_calibration_qconfig_setter, ), )
  1. 自定义模板。自定义模板只有 ModuleNameQconfigSetter,需要传入模块名和对应 qconfig 的字典,一般用于设置 fixed scale 等特殊需求,可以和固定模板,敏感度模板搭配使用。
from horizon_plugin_pytorch.quantization.qconfig_template import ( default_qat_qconfig_setter, sensitive_op_qat_8bit_weight_16bit_fixed_act_qconfig_setter, ModuleNameQconfigSetter, ) table = torch.load("output_0-0_dataindex_1_sensitive_ops.pt") module_name_to_qconfig = { "op_1": get_qconfig(), # 自动替换生成的算子只能通过 ModuleNameQconfigSetter 配置自定义 qconfig。 "_generated_add_0": QConfig( output=FakeQuantize.with_args( observer=FixedScaleObserver, dtype=qint16, scale=OP2_MAX/QINT16_MAX, ) ), } qat_model = prepare( model, example_inputs=example_input, qconfig_setter=( ModuleNameQconfigSetter(module_name_to_qconfig), sensitive_op_qat_8bit_weight_16bit_fixed_act_qconfig_setter(table, ratio=0.2), default_qat_qconfig_setter, ), )

QAT 各阶段 QConfig 的行为

本章节详细阐述 qconfig 在 QAT 各个阶段的行为,需要在深入了解地平线 QAT 工具的基础上阅读。

设置 QConfig 属性

设置完成后,设置的 module 会有一个 qconfig 属性。

算子替换

算子替换分为三种:

  1. 将 function 类算子替换为 module,例如:“ + ” 被替换为 “ generated_add_x ”。由于替换之前 function 是没有 qconfig 属性的,所以替换完成后也没有,qconfig 只能通过扩散机制从 parent module 继承或通过模板设置。

  2. 将 module 算子拆分为多个小的拼接算子,例如:“ horizon_plugin_pytorch.nn.Norm ” 会被拆分为 mul / sum / sqrt。如果 module 算子在拆分前已经有了 qconfig 属性,那么会调用算子的 propagate_qconfig 方法给小算子设置 qconfig。

def propagate_qconfig(self, qconfig): from horizon_plugin_pytorch.quantization.qconfig import ( promote_int8_activation_to_int16, ) int16_qconfig = promote_int8_activation_to_int16(qconfig) self.mul.qconfig = int16_qconfig self.sum.qconfig = int16_qconfig self.sqrt.qconfig = qconfig
  1. 1 和 2 的综合情况,先把 function 类算子替换为 module,再将 module 算子拆分为多个小的拼接算子,例如:“ torch.norm ” 会先被替换为 “ horizon_plugin_pytorch.nn.Norm ”,再拆分成小算子。

算子融合

  1. 在算子融合前,配置了融合模块的 qconfig 属性,融合后的模块会继承融合前最后一个模块的 qconfig。
if hasattr(op_list[-1], "qconfig"): fused.qconfig = op_list[-1].qconfig
  1. 如果想要直接设置融合后模块的 qconfig,可以通过模板进行设置。

QConfig 扩散

从 parent module 向 children module 扩散 qconfig。

  1. 深度优先遍历,为所有 children module 执行同样的操作。

  2. 对于当前遍历到的 module,当同时有 parent module 扩散下来的 qconfig 和自己本身的 qconfig 属性时,以自己的 qconfig 属性为准。在没有写 3 中的 propagate_qconfig 方法的前提下,parent module qconfig 只影响没有配置 qconfig 的 children module。

  3. 试图调用 propagate_qconfig 方法向 children module 扩散自己的 qconfig,propagate_qconfig 是按照以往经验写死的一些配置。当前的主要作用是给浮点阶段拆分的拼接算子设置小算子的 qconfig。

def _propagate_qconfig_helper( module, qconfig_dict, qconfig_parent=NotSetQConfig, prefix="", unused_qconfig_key=None, white_list: Optional[List[str]] = None, ): ... module_qconfig = getattr(module, "qconfig", module_qconfig) if module_qconfig is not NotSetQConfig: if hasattr(module, "propagate_qconfig"): module.propagate_qconfig(module_qconfig) module.qconfig = module_qconfig for name, child in module.named_children(): module_prefix = prefix + "." + name if prefix else name _propagate_qconfig_helper( child, qconfig_dict, module_qconfig, module_prefix, unused_qconfig_key, white_list, ) ...

QConfig Setter

由一系列 setter 组成,setter 的原则是不改变已有的 qconfig 属性,对于没有 qconfig 属性的 module,按照不同 setter 的规则自动设置 qconfig 属性。所以有多个 setter 时,前面的 setter 优先级更高。

因为 setter 处在算子替换和算子融合的后面,所以可以直接设置替换后和融合后算子的 qconfig。

QConfig 规范化

QConfig 规范化机制有两条,都是基于图去做的:

  1. 按照 march 获取硬件限制,将 qconfig 配置的一些 int16 的回退 int8。有一个特殊情况是 matmul 支持单 int16,在有敏感度 setter 的情况下,自动回退敏感度较小的分支,在没有敏感度 setter 的情况下,自动回退拓扑序靠后的分支。

  2. 把连续的量化节点缩减成一个。当前逻辑还比较简单,前一个 module 的输出量化节点 dtype 与后一个 module 的输入量化节点 dtype 一致时,去除后一个 module 的输入量化节点。

插入伪量化节点

调用 QAT 算子的 “ from_float ” 方法,将浮点算子替换为 QAT 算子,在 QAT 算子的 “ init ” 中根据 qconfig 初始化量化节点,在 QAT 算子的 forward 中调用量化节点。

class ConvBN2d(nn.Conv2d): def __init__(...): ... assert qconfig, "qconfig must be provided for QAT module" self.qconfig = qconfig if self.qconfig.activation is not None: self.activation_post_process = self.qconfig.activation() else: self.activation_post_process = None if self.qconfig.weight is None: raise ValueError("qconfig must include weight") self.weight_fake_quant = self.qconfig.weight( channel_len=self.out_channels ) def from_float(cls, mod): ... qat_conv_bn = cls( conv.in_channels, conv.out_channels, conv.kernel_size, stride=conv.stride, padding=conv.padding, dilation=conv.dilation, groups=conv.groups, bias=conv.bias is not None, padding_mode=conv.padding_mode, qconfig=mod.qconfig, ) ... return qat_conv_bn def forward(self, input): ... out = self._conv_bn_forward(input) if self.activation_post_process is not None: return self.activation_post_process(out) else: return out