QConfig 详解
定义
QConfig 的定义
模型的量化方式由 qconfig 决定,在准备 qat / calibration 模型之前,需要先给模型设置 qconfig。
注意
因历史原因,Plugin 中有不同 qconfig 的定义和用法,早期版本的 qconfig 将在不久的将来被废弃,我们只推荐您使用此文档中介绍的 qconfig 用法。
一个 qconfig 对象可以设置 input / weight / output 三个关键字,分别表示算子输入/权重/输出的量化配置,prepare 模型时会根据这些配置决定是否要在对应位置插入 FakeQuantize / FakeCast 节点,None 表示不插入任何节点。
import torch
from horizon_plugin_pytorch.quantization.qconfig import QConfig
from horizon_plugin_pytorch.quantization.fake_quantize import FakeQuantize
from horizon_plugin_pytorch.quantization.fake_cast import FakeCast
from horizon_plugin_pytorch.quantization.observer_v2 import MinMaxObserver
from horizon_plugin_pytorch.dtype import qint8
qconfig = QConfig(
input=None,
weight=FakeQuantize.with_args(
observer=MinMaxObserver,
dtype=qint8,
qscheme=torch.per_channel_symmetric,
ch_axis=0,
),
output=FakeCast.with_args(dtype=torch.float16),
# activation=xxx 早期用法,作用与 output 关键字一致,当前仍兼容,但建议您使用 output 关键字。
)
FakeQuantize 的定义
FakeQuantize 是伪量化节点,会对输入进行量化反量化操作,插入伪量化可以在浮点模型的前向中模拟量化产生的误差。horizon_plugin_pytorch 支持 FakeQuantize / PACTFakeQuantize / _LearnableFakeQuantize 三种伪量化,我们只推荐您使用基于统计的 FakeQuantize,可以满足绝大部分需求。标准流程不对 PACTFakeQuantize 和 _LearnableFakeQuantize 两种方法做详细说明,如果一定有需求,请在阅读相关论文后再使用。
# 基于统计的方法
from horizon_plugin_pytorch.quantization.fake_quantize import FakeQuantize
# https://arxiv.org/pdf/1805.06085
from horizon_plugin_pytorch.quantization.pact_fake_quantize import PACTFakeQuantize
# https://arxiv.org/pdf/1902.08153
from horizon_plugin_pytorch.quantization._learnable_fake_quantize import _LearnableFakeQuantize
可以调用 FakeQuantize 的 with_args 方法得到构造器,并按上一节的代码示例用它构造 qconfig。with_args 的参数包括 FakeQuantize 和 observer 支持配置的参数,理论上可以配置所有 FakeQuantize 和 observer 类 init 方法声明中的参数,但为了屏蔽无关紧要的细节,我们只推荐您配置 observer 相关参数。
不同 observer 的参数不同,下面列出常用 observer 构造 FakeQuantize 的例子,其他 observer 的具体用法见校准章节。
import torch
from horizon_plugin_pytorch.quantization.qconfig import QConfig
from horizon_plugin_pytorch.quantization.fake_quantize import FakeQuantize
from horizon_plugin_pytorch.quantization.observer_v2 import MinMaxObserver, FixedScaleObserver, MSEObserver
from horizon_plugin_pytorch.dtype import qint8
# MinMaxObserver 的 __init__ 方法包含很多参数,with_args 方法可以控制这些参数。
# 我们只推荐您设置 fq_constructor_1 示例中的几个参数。
# def __init__(
# self,
# averaging_constant: float = 0.01,
# ch_axis: int = -1,
# dtype: Union[torch.dtype, QuantDType] = qint8,
# qscheme: torch.qscheme = torch.per_tensor_symmetric,
# quant_min: int = None,
# quant_max: int = None,
# is_sync_quantize: bool = False,
# factory_kwargs: Dict = None,
# ) -> None:
fq_constructor_1 = FakeQuantize.with_args(
observer=MinMaxObserver, # 适用于 qat 阶段的 input / output / weight 和 calibration 阶段的 weight。
averaging_constant=0.01, # calibration 后进行 qat 时,可将 input / output 的 averaging_constant 置为 0 以固定 scale。
dtype=qint8, # 量化类型,考虑算子的支持情况进行设置。
qscheme=torch.per_channel_symmetric, # 只有 weight 支持 per channel 量化。
ch_axis=0, # per channel 量化时指定 channel。
)
# 同理,您也可以查看 FixedScaleObserver 和 MSEObserver 的 __init__ 方法了解有哪些可以设置的参数。
fq_constructor_2 = FakeQuantize.with_args(
observer=FixedScaleObserver, # 固定 scale,无论何种情况都不会变。
dtype=qint8, # 量化类型,考虑算子的支持情况进行设置。
scale=INPUT_ABS_MAX / 128, # 设定的 scale 值,一般设为绝对值最大值除以量化类型最大值。
)
fq_constructor_3 = FakeQuantize.with_args(
observer=MSEObserver, # 适用于 calibration 阶段的 input / output。
dtype=qint8, # 量化类型,考虑算子的支持情况进行设置。
)
qconfig = QConfig(
weight=fq_constructor_x,
...
)
FakeCast 的定义
FakeCast 是伪转换节点,会将输入转换为 float32 类型,如果数据类型是 float16,那么还会在中间模拟转 float16 产生的截断误差,此节点主要用于标志需要浮点计算的算子。
使用 FakeCast 构造 qconfig 的方法与 FakeQuantize 类似,但只有 dtype 一个参数。
import torch
from horizon_plugin_pytorch.quantization.qconfig import QConfig
from horizon_plugin_pytorch.quantization.fake_cast import FakeCast
qconfig = QConfig(
input=FakeCast.with_args(dtype=torch.float16), # 考虑算子的支持情况进行设置。
...
)
构造 QConfig
-
按照上文介绍的方法,直接构造 QConfig 对象。这种方法比较灵活,可以配置任何可配置的参数,需要您对 QConfig 有一定的理解。
-
使用 get_qconfig
接口。此接口较直接构造 QConfig 对象的方法更简单易用,但不够灵活,高级用法和需求无法使用此接口实现。
import torch
from horizon_plugin_pytorch.quantization import get_qconfig
from horizon_plugin_pytorch.quantization.observer_v2 import MinMaxObserver
from horizon_plugin_pytorch.quantization.qconfig import QConfig
from horizon_plugin_pytorch.quantization.fake_quantize import FakeQuantize
from horizon_plugin_pytorch.dtype import qint8
# qconfig_1 / qconfig_2 / qconfig_3 / qconfig_4 等价。
qconfig_1 = QConfig(
weight=FakeQuantize.with_args(
observer=MinMaxObserver,
averaging_constant=0.01,
dtype=qint8,
qscheme=torch.per_channel_symmetric,
ch_axis=0,
),
output=FakeQuantize.with_args(
observer=MinMaxObserver,
averaging_constant=0,
dtype=qint8,
qscheme=torch.per_tensor_symmetric,
ch_axis=-1,
),
)
qconfig_2 = QConfig(
weight=FakeQuantize.with_args(
observer=MinMaxObserver,
qscheme=torch.per_channel_symmetric,
ch_axis=0,
),
output=FakeQuantize.with_args(
observer=MinMaxObserver,
averaging_constant=0,
),
)
qconfig_3 = get_qconfig(
observer=MinMaxObserver, # 输入输出 observer 类型,只支持 horizon_plugin_pytorch.quantization.observer_v2 中的 MinMaxObserver 和 MSEObserver,默认值为 MinMaxObserver。
in_dtype=None, # 输入数据类型,考虑算子的支持情况进行设置。None 表示 QConfig 的 input 关键字为 None,默认值为 None。
weight_dtype=qint8, # 权重数据类型,考虑算子的支持情况进行设置。None 表示 QConfig 的 weight 关键字为 None,默认值为 qint8。
out_dtype=qint8, # 输出数据类型,考虑算子的支持情况进行设置。None 表示 QConfig 的 output 关键字为 None,默认值为 qint8。
fix_scale=True, # 是否固定输入输出 scale。
)
qconfig_4 = get_qconfig(fix_scale=True)
使用方法
设置 QConfig 属性
直接设置 qconfig 属性。此方法优先级最高,其余方法不会覆盖直接设置的 qconfig。
model.qconfig = QConfig(...)
QConfig 模板
qconfig 模板基于 subclass trace 方案感知模型的图结构,并按设定的规则自动设置 qconfig,是我们最推荐的设置 qconfig 方法。使用模板需要在 prepare 接口上指定 qconfig setter 和 example_inputs。用法如下:
from horizon_plugin_pytorch.quantization import prepare
from horizon_plugin_pytorch.quantization.qconfig_template import (
default_qat_qconfig_setter,
sensitive_op_qat_8bit_weight_16bit_act_qconfig_setter
)
qat_model = prepare(
model,
example_inputs=example_inputs, # 用来感知图结构
qconfig_setter=( # qconfig 模板,支持传入多个模板,优先级从高到低。
sensitive_op_qat_8bit_weight_16bit_act_qconfig_setter(table, ratio=0.2),
default_qat_qconfig_setter,
),
)
注意
模板的优先级低于直接给模型设置 qconfig 属性,如果模型在 prepare 之前已经使用 model.qconfig = xxx 进行了配置,那么模板将不会生效。如果没有特殊需求,我们不推荐将两者混合使用,这很容易引发低级错误。绝大多数情况下,使用模板和 model.qconfig = xxx 两种设置方式中的一种即可满足需求。
模板可分为三类:
- 固定模板。固定模板中 calibration / qat / qat_fixed_act_scale 区别在于使用的 observer 类型和 scale 更新逻辑,分别用于校准,qat 训练,固定 activation scale qat 训练。default 模板 ( default_calibration_qconfig_setter / default_qat_qconfig_setter / default_qat_fixed_act_qconfig_setter ) 会做三件事:首先,将可以设置的高精度输出都设置上,对于不支持高精度的输出将给出提示;然后,从 grid sample 算子的 grid 输入向前搜索,直到出现第一个 gemm 类算子或者 QuantStub,将中间的所有算子都设置为 int16。根据经验这里的 grid 一般表达范围较宽,int8 有较大可能不满足精度需求;最后,将其余算子设置为 int8。int16 模板 ( qat_8bit_weight_16bit_act_qconfig_setter / qat_8bit_weight_16bit_fixed_act_qconfig_setter / calibration_8bit_weight_16bit_act_qconfig_setter ) 会做两件事:首先,将可以设置的高精度输出都设置上,对于不支持高精度的输出将给出提示;其次,将其余算子设置为 int16。
from horizon_plugin_pytorch.quantization.qconfig_template import (
default_calibration_qconfig_setter,
default_qat_qconfig_setter,
default_qat_fixed_act_qconfig_setter,
qat_8bit_weight_16bit_act_qconfig_setter,
qat_8bit_weight_16bit_fixed_act_qconfig_setter,
calibration_8bit_weight_16bit_act_qconfig_setter,
)
- 敏感度模板。敏感度模板有 sensitive_op_calibration_8bit_weight_16bit_act_qconfig_setter,sensitive_op_qat_8bit_weight_16bit_act_qconfig_setter,sensitive_op_qat_8bit_weight_16bit_fixed_act_qconfig_setter,三者的区别和固定模板中三者的区别一致,也是分别用于校准,qat 训练,固定 activation scale qat 训练。
敏感度模板的第一个输入是精度 debug 工具产生的敏感度结果,第二个参数可以指定 ratio 或 topk,敏感度模板会将量化敏感度最高的 topk 个算子设置为 int16。搭配固定模板,可以轻松实现混合精度调优。若模型有多个输出,每个输出都会产生一个敏感度表,您可以设置多个敏感度模版。
from horizon_plugin_pytorch.quantization.qconfig_template import (
default_calibration_qconfig_setter,
sensitive_op_qat_8bit_weight_16bit_act_qconfig_setter,
sensitive_op_qat_8bit_weight_16bit_fixed_act_qconfig_setter,
sensitive_op_calibration_8bit_weight_16bit_act_qconfig_setter,
)
table1 = torch.load("output_0-0_L1_sensitive_ops.pt")
table2 = torch.load("output_0-1_L1_sensitive_ops.pt")
calibration_model = prepare(
model,
example_inputs=example_input,
qconfig_setter=(
sensitive_op_calibration_8bit_weight_16bit_act_qconfig_setter(table1, ratio=0.2),
sensitive_op_calibration_8bit_weight_16bit_act_qconfig_setter(table2, ratio=0.2),
default_calibration_qconfig_setter,
),
)
- 自定义模板。自定义模板只有 ModuleNameQconfigSetter,需要传入模块名和对应 qconfig 的字典,一般用于设置 fixed scale 等特殊需求,可以和固定模板,敏感度模板搭配使用。
from horizon_plugin_pytorch.quantization.qconfig_template import (
default_qat_qconfig_setter,
sensitive_op_qat_8bit_weight_16bit_fixed_act_qconfig_setter,
ModuleNameQconfigSetter,
)
table = torch.load("output_0-0_dataindex_1_sensitive_ops.pt")
module_name_to_qconfig = {
"op_1": get_qconfig(),
# 自动替换生成的算子只能通过 ModuleNameQconfigSetter 配置自定义 qconfig。
"_generated_add_0": QConfig(
output=FakeQuantize.with_args(
observer=FixedScaleObserver,
dtype=qint16,
scale=OP2_MAX/QINT16_MAX,
)
),
}
qat_model = prepare(
model,
example_inputs=example_input,
qconfig_setter=(
ModuleNameQconfigSetter(module_name_to_qconfig),
sensitive_op_qat_8bit_weight_16bit_fixed_act_qconfig_setter(table, ratio=0.2),
default_qat_qconfig_setter,
),
)
QAT 各阶段 QConfig 的行为
本章节详细阐述 qconfig 在 QAT 各个阶段的行为,需要在深入了解地平线 QAT 工具的基础上阅读。
设置 QConfig 属性
设置完成后,设置的 module 会有一个 qconfig 属性。
算子替换
算子替换分为三种:
-
将 function 类算子替换为 module,例如:“ + ” 被替换为 “ generated_add_x ”。由于替换之前 function 是没有 qconfig 属性的,所以替换完成后也没有,qconfig 只能通过扩散机制从 parent module 继承或通过模板设置。
-
将 module 算子拆分为多个小的拼接算子,例如:“ horizon_plugin_pytorch.nn.Norm ” 会被拆分为 mul / sum / sqrt。如果 module 算子在拆分前已经有了 qconfig 属性,那么会调用算子的 propagate_qconfig 方法给小算子设置 qconfig。
def propagate_qconfig(self, qconfig):
from horizon_plugin_pytorch.quantization.qconfig import (
promote_int8_activation_to_int16,
)
int16_qconfig = promote_int8_activation_to_int16(qconfig)
self.mul.qconfig = int16_qconfig
self.sum.qconfig = int16_qconfig
self.sqrt.qconfig = qconfig
- 1 和 2 的综合情况,先把 function 类算子替换为 module,再将 module 算子拆分为多个小的拼接算子,例如:“ torch.norm ” 会先被替换为 “ horizon_plugin_pytorch.nn.Norm ”,再拆分成小算子。
算子融合
- 在算子融合前,配置了融合模块的 qconfig 属性,融合后的模块会继承融合前最后一个模块的 qconfig。
if hasattr(op_list[-1], "qconfig"):
fused.qconfig = op_list[-1].qconfig
- 如果想要直接设置融合后模块的 qconfig,可以通过模板进行设置。
QConfig 扩散
从 parent module 向 children module 扩散 qconfig。
-
深度优先遍历,为所有 children module 执行同样的操作。
-
对于当前遍历到的 module,当同时有 parent module 扩散下来的 qconfig 和自己本身的 qconfig 属性时,以自己的 qconfig 属性为准。在没有写 3 中的 propagate_qconfig 方法的前提下,parent module qconfig 只影响没有配置 qconfig 的 children module。
-
试图调用 propagate_qconfig 方法向 children module 扩散自己的 qconfig,propagate_qconfig 是按照以往经验写死的一些配置。当前的主要作用是给浮点阶段拆分的拼接算子设置小算子的 qconfig。
def _propagate_qconfig_helper(
module,
qconfig_dict,
qconfig_parent=NotSetQConfig,
prefix="",
unused_qconfig_key=None,
white_list: Optional[List[str]] = None,
):
...
module_qconfig = getattr(module, "qconfig", module_qconfig)
if module_qconfig is not NotSetQConfig:
if hasattr(module, "propagate_qconfig"):
module.propagate_qconfig(module_qconfig)
module.qconfig = module_qconfig
for name, child in module.named_children():
module_prefix = prefix + "." + name if prefix else name
_propagate_qconfig_helper(
child,
qconfig_dict,
module_qconfig,
module_prefix,
unused_qconfig_key,
white_list,
)
...
QConfig Setter
由一系列 setter 组成,setter 的原则是不改变已有的 qconfig 属性,对于没有 qconfig 属性的 module,按照不同 setter 的规则自动设置 qconfig 属性。所以有多个 setter 时,前面的 setter 优先级更高。
因为 setter 处在算子替换和算子融合的后面,所以可以直接设置替换后和融合后算子的 qconfig。
QConfig 规范化
QConfig 规范化机制有两条,都是基于图去做的:
-
按照 march 获取硬件限制,将 qconfig 配置的一些 int16 的回退 int8。有一个特殊情况是 matmul 支持单 int16,在有敏感度 setter 的情况下,自动回退敏感度较小的分支,在没有敏感度 setter 的情况下,自动回退拓扑序靠后的分支。
-
把连续的量化节点缩减成一个。当前逻辑还比较简单,前一个 module 的输出量化节点 dtype 与后一个 module 的输入量化节点 dtype 一致时,去除后一个 module 的输入量化节点。
插入伪量化节点
调用 QAT 算子的 “ from_float ” 方法,将浮点算子替换为 QAT 算子,在 QAT 算子的 “ init ” 中根据 qconfig 初始化量化节点,在 QAT 算子的 forward 中调用量化节点。
class ConvBN2d(nn.Conv2d):
def __init__(...):
...
assert qconfig, "qconfig must be provided for QAT module"
self.qconfig = qconfig
if self.qconfig.activation is not None:
self.activation_post_process = self.qconfig.activation()
else:
self.activation_post_process = None
if self.qconfig.weight is None:
raise ValueError("qconfig must include weight")
self.weight_fake_quant = self.qconfig.weight(
channel_len=self.out_channels
)
def from_float(cls, mod):
...
qat_conv_bn = cls(
conv.in_channels,
conv.out_channels,
conv.kernel_size,
stride=conv.stride,
padding=conv.padding,
dilation=conv.dilation,
groups=conv.groups,
bias=conv.bias is not None,
padding_mode=conv.padding_mode,
qconfig=mod.qconfig,
)
...
return qat_conv_bn
def forward(self, input):
...
out = self._conv_bn_forward(input)
if self.activation_post_process is not None:
return self.activation_post_process(out)
else:
return out