prepare 详解

prepare 的定义

prepare 是将浮点模型转换为伪量化模型的过程。这个过程会做以下几件事情:

  1. 算子替换:部分 torch function 类型的算子(例如 F.interpolate)在量化时需要插入伪量化节点,因此需要将算子替换为对应的 Module 类型实现(horizon_plugin_pytorch.nn.Interpolate),以将伪量化节点放在此 Module 内部。替换前后的模型是等价的。

  2. 算子融合:BPU 支持将特定的计算 pattern 进行融合,融合后算子中间结果用高精度表示,因此我们将被融合的多个算子替换为一个 Module,以阻止中间结果的量化。融合前后的模型也是等价的。

  3. 算子转换:将浮点算子替换为 qat 算子。按照设置的 qconfig,qat 算子会在输入/输出/权重处添加伪量化/伪转换节点。

注意

请确保 prepare 之后不会再修改模型,否则已经被替换的 qat 算子可能产生不符合预期的行为。例如:prepare 之后再将未融合的 bn 转为 sync bn 可能导致 qat bn 被再次修改为 sync bn,应该在prepare之前将它转为 sync bn。

  1. 模型结构检查:检查 qat 模型,生成检查结果文件。

prepare 接口的用法如下:

from horizon_plugin_pytorch.quantization.prepare import prepare, PrepareMethod from horizon_plugin_pytorch.quantization.qconfig_template import ( default_qat_qconfig_setter, sensitive_op_qat_8bit_weight_16bit_fixed_act_qconfig_setter, ) # 使用模板时必须提供 example_inputs 和 qconfig_setter。 # method 为 PrepareMethod.JIT_STRIP 或 PrepareMethod.JIT 时,必须提供 example_inputs。 # def prepare( # model: torch.nn.Module, # example_inputs: Any = None, # 用来感知图结构,确保可以用来跑通 forward。 # qconfig_setter: Optional[Union[Tuple[QconfigSetterBase, ...], QconfigSetterBase]] = None, # qconfig 模板,支持传入多个模板,优先级从高到低。 # method: PrepareMethod = PrepareMethod.JIT_STRIP, # prepare 模式 # ) -> torch.nn.Module: qat_model = prepare( float_model, example_inputs=example_inputs, qconfig_setter=( sensitive_op_qat_8bit_weight_16bit_fixed_act_qconfig_setter(table, ratio=0.2), default_qat_qconfig_setter, ), method=PrepareMethod.JIT, )

PrepareMethod

prepare 有四种 method,他们的对比如下:

method原理优点缺点
Graph Mode使用 hook 和 subclass tensor 的方式感知图结构,在原有 forward 上做算子替换/算子融合等操作。全自动,代码修改少,屏蔽了很多细节问题,便于 debug。动态代码块需要特殊处理。
PrepareMethod.EAGER不感知图结构,算子替换/算子融合需手动进行。用法灵活,过程可控,便于 debug 和处理各类特殊需求。手动操作较多,代码修改多,上手成本高。

目前,JIT 和 JIT_STRIP 为我们推荐的 method,两者的区别在于 JIT_STRIP 会根据模型中 QuantStub 和 DequantStub 的位置识别并跳过前后处理,因此当模型中存在不需要量化的前后处理时,请使用 JIT_STRIP,否则它们将被量化,除此以外,两者完全一致。SYMBOLIC 和 EAGER 为早期方案,存在较多易用性问题,我们建议您不要使用这两种方案。

使用示例

import copy import numpy as np import torch from torch import nn from torch.nn import functional as F from torch.quantization import DeQuantStub, QuantStub from horizon_plugin_pytorch import March, set_march from horizon_plugin_pytorch.fx.jit_scheme import Tracer from horizon_plugin_pytorch.quantization import ( FakeQuantState, get_qconfig, PrepareMethod, prepare, set_fake_quantize, ) class Net(torch.nn.Module): def __init__(self, input_size, class_num) -> None: super().__init__() self.quant0 = QuantStub() self.quant1 = QuantStub() self.dequant = DeQuantStub() self.conv = nn.Conv2d(3, 3, 1) self.bn = nn.BatchNorm2d(3) self.classifier = nn.Conv2d(3, class_num, input_size) self.loss = nn.CrossEntropyLoss() def forward(self, input, other, target=None): # 不需要量化的前处理,使用 JIT_STRIP 时,将这些操作从计算图中剔除。 input = (input - 128) / 128.0 x = self.quant0(input) y = self.quant1(other) n = np.random.randint(1, 5) m = np.random.randint(1, 5) # 由于不重新生成 python code,此动态循环在 QAT 模型中保留。 for _ in range(n): for _ in range(m): # 动态循环中的代码块涉及到算子替换或算子融合时,必须进行标注。 # 标注的是需要算子替换或算子融合的逻辑,而不是 for 循环。 with Tracer.dynamic_block(self, "ConvBnAdd"): x = self.conv(x) x = self.bn(x) x = x + y x = self.classifier(x).squeeze() # 由于不重新生成 python code,此动态控制流在 QAT 模型中保留 if self.training: assert target is not None x = self.dequant(x) return F.cross_entropy(torch.softmax(x, dim=1), target) else: return torch.argmax(x, dim=1) model = Net(6, 2) train_example_input = ( torch.rand(2, 3, 6, 6) * 256, torch.rand(2, 3, 6, 6), torch.tensor([[0.0, 1.0], [1.0, 0.0]]), ) eval_example_input = train_example_input[:2] model.eval() set_march(March.NASH_E) model.qconfig = get_qconfig() qat_model = prepare( model, example_inputs=copy.deepcopy(eval_example_input), method=PrepareMethod.JIT_STRIP, ) qat_model.graph.print_tabular() # opcode name target args kwargs # ------------- ---------------- --------------------------------------------------------- -------------------------------- ---------- # placeholder input_0 input_0 () {} # call_module quant0 quant0 (input_0,) {} # placeholder input_1 input_1 () {} # call_module quant1 quant1 (input_1,) {} # call_module conv conv (quant0,) {} # call_module bn bn (conv,) {} # get_attr _generated_add_0 _generated_add_0 () {} # call_method add_2 add (_generated_add_0, bn, quant1) {} # scope_end 是在 trace 过程中自动插入的,用于标记子 module 或动态代码块的边界,不对应实际计算 # call_function scope_end <function Tracer.scope_end at 0x7f65d90e5e50> ('_dynamic_block_ConvBnAdd',) {} # call_module conv_1 conv (add_2,) {} # call_module bn_1 bn (conv_1,) {} # get_attr _generated_add_1 _generated_add_0 () {} # call_method add_3 add (_generated_add_1, bn_1, quant1) {} # call_function scope_end_1 <function Tracer.scope_end at 0x7f65d90e5e50> ('_dynamic_block_ConvBnAdd',) {} # call_module classifier classifier (add_3,) {} # call_function squeeze <method 'squeeze' of 'torch._C._TensorBase' objects> (classifier,) {} # call_function argmax <built-in method argmax of type object at 0x7f66f04cf820> (squeeze,) {'dim': 1} # call_function scope_end_2 <function Tracer.scope_end at 0x7f65d90e5e50> ('',) {} # output output output ((argmax,),) {} print(qat_model) # GraphModuleImpl( # (quant0): QuantStub( # (activation_post_process): FakeQuantize( # fake_quant_enabled=tensor([1], dtype=torch.uint8), observer_enabled=tensor([1], dtype=torch.uint8), quant_min=-128, quant_max=127, dtype=qint8, qscheme=torch.per_tensor_symmetric, ch_axis=-1, scale=tensor([1.]), zero_point=tensor([0]) # (activation_post_process): MinMaxObserver(min_val=tensor([]), max_val=tensor([])) # ) # ) # (quant1): QuantStub( # (activation_post_process): FakeQuantize( # fake_quant_enabled=tensor([1], dtype=torch.uint8), observer_enabled=tensor([1], dtype=torch.uint8), quant_min=-128, quant_max=127, dtype=qint8, qscheme=torch.per_tensor_symmetric, ch_axis=-1, scale=tensor([1.]), zero_point=tensor([0]) # (activation_post_process): MinMaxObserver(min_val=tensor([]), max_val=tensor([])) # ) # ) # (dequant): DeQuantStub() # (conv): Identity() # 由于 forward 代码不变,conv 和 bn 仍将被执行,所以融合后必须将 Conv 和 Bn 替换为 Identity # (bn): Identity() # (classifier): Conv2d( # 3, 2, kernel_size=(6, 6), stride=(1, 1) # (activation_post_process): FakeQuantize( # fake_quant_enabled=tensor([1], dtype=torch.uint8), observer_enabled=tensor([1], dtype=torch.uint8), quant_min=-128, quant_max=127, dtype=qint8, qscheme=torch.per_tensor_symmetric, ch_axis=-1, scale=tensor([1.]), zero_point=tensor([0]) # (activation_post_process): MinMaxObserver(min_val=tensor([]), max_val=tensor([])) # ) # (weight_fake_quant): FakeQuantize( # fake_quant_enabled=tensor([1], dtype=torch.uint8), observer_enabled=tensor([1], dtype=torch.uint8), quant_min=-128, quant_max=127, dtype=qint8, qscheme=torch.per_channel_symmetric, ch_axis=0, scale=tensor([1., 1.]), zero_point=tensor([0, 0]) # (activation_post_process): MinMaxObserver(min_val=tensor([]), max_val=tensor([])) # ) # ) # (loss): CrossEntropyLoss() # (_generated_add_0): ConvAdd2d( # 自动将 '+' 替换为 Module 形式,并将 Conv 和 Bn 融合进来 # 3, 3, kernel_size=(1, 1), stride=(1, 1) # (activation_post_process): FakeQuantize( # fake_quant_enabled=tensor([1], dtype=torch.uint8), observer_enabled=tensor([1], dtype=torch.uint8), quant_min=-128, quant_max=127, dtype=qint8, qscheme=torch.per_tensor_symmetric, ch_axis=-1, scale=tensor([1.]), zero_point=tensor([0]) # (activation_post_process): MinMaxObserver(min_val=tensor([]), max_val=tensor([])) # ) # (weight_fake_quant): FakeQuantize( # fake_quant_enabled=tensor([1], dtype=torch.uint8), observer_enabled=tensor([1], dtype=torch.uint8), quant_min=-128, quant_max=127, dtype=qint8, qscheme=torch.per_channel_symmetric, ch_axis=0, scale=tensor([1., 1., 1.]), zero_point=tensor([0, 0, 0]) # (activation_post_process): MinMaxObserver(min_val=tensor([]), max_val=tensor([])) # ) # ) # ) qat_model.train() set_fake_quantize(qat_model, FakeQuantState.QAT) for _ in range(3): ret = qat_model(*train_example_input) ret.backward()
注意
  1. 动态代码块涉及到算子替换或算子融合时,必须使用 Tracer.dynamic_block 进行标注,否则将导致量化信息错乱或 forward 报错。
  2. 模型中调用次数变化的部分(子 module 或 dynamic_block),若在 trace 时仅执行了一次,则有可能和非动态部分产生算子融合,导致 forward 报错。

模型检查

在提供 example_inputs 的情况下,prepare 默认会对模型结构进行检查。如果检查完成,可以在运行目录下找到 model_check_result.txt 文件,如果检查失败,则需要根据警告提示修改模型或单独调用 horizon_plugin_pytorch.utils.check_model.check_qat_model 检查模型。检查流程和 debug 工具中的 check_qat_model 一致,结果文件的分析详见 check_qat_model 相关文档。