prepare 是将浮点模型转换为伪量化模型的过程。这个过程会做以下几件事情:
算子替换:部分 torch function 类型的算子(例如 F.interpolate)在量化时需要插入伪量化节点,因此需要将算子替换为对应的 Module 类型实现(horizon_plugin_pytorch.nn.Interpolate),以将伪量化节点放在此 Module 内部。替换前后的模型是等价的。
算子融合:BPU 支持将特定的计算 pattern 进行融合,融合后算子中间结果用高精度表示,因此我们将被融合的多个算子替换为一个 Module,以阻止中间结果的量化。融合前后的模型也是等价的。
算子转换:将浮点算子替换为 qat 算子。按照设置的 qconfig,qat 算子会在输入/输出/权重处添加伪量化/伪转换节点。
请确保 prepare 之后不会再修改模型,否则已经被替换的 qat 算子可能产生不符合预期的行为。例如:prepare 之后再将未融合的 bn 转为 sync bn 可能导致 qat bn 被再次修改为 sync bn,应该在prepare之前将它转为 sync bn。
prepare 接口的用法如下:
from horizon_plugin_pytorch.quantization.prepare import prepare, PrepareMethod
from horizon_plugin_pytorch.quantization.qconfig_template import (
default_qat_qconfig_setter,
sensitive_op_qat_8bit_weight_16bit_fixed_act_qconfig_setter,
)
# 使用模板时必须提供 example_inputs 和 qconfig_setter。
# method 为 PrepareMethod.JIT_STRIP 或 PrepareMethod.JIT 时,必须提供 example_inputs。
# def prepare(
# model: torch.nn.Module,
# example_inputs: Any = None, # 用来感知图结构,确保可以用来跑通 forward。
# qconfig_setter: Optional[Union[Tuple[QconfigSetterBase, ...], QconfigSetterBase]] = None, # qconfig 模板,支持传入多个模板,优先级从高到低。
# method: PrepareMethod = PrepareMethod.JIT_STRIP, # prepare 模式
# ) -> torch.nn.Module:
qat_model = prepare(
float_model,
example_inputs=example_inputs,
qconfig_setter=(
sensitive_op_qat_8bit_weight_16bit_fixed_act_qconfig_setter(table, ratio=0.2),
default_qat_qconfig_setter,
),
method=PrepareMethod.JIT,
)
prepare 有四种 method,他们的对比如下:
method | 原理 | 优点 | 缺点 |
---|---|---|---|
Graph Mode | 使用 hook 和 subclass tensor 的方式感知图结构,在原有 forward 上做算子替换/算子融合等操作。 | 全自动,代码修改少,屏蔽了很多细节问题,便于 debug。 | 动态代码块需要特殊处理。 |
PrepareMethod.EAGER | 不感知图结构,算子替换/算子融合需手动进行。 | 用法灵活,过程可控,便于 debug 和处理各类特殊需求。 | 手动操作较多,代码修改多,上手成本高。 |
目前,JIT 和 JIT_STRIP 为我们推荐的 method,两者的区别在于 JIT_STRIP 会根据模型中 QuantStub 和 DequantStub 的位置识别并跳过前后处理,因此当模型中存在不需要量化的前后处理时,请使用 JIT_STRIP,否则它们将被量化,除此以外,两者完全一致。SYMBOLIC 和 EAGER 为早期方案,存在较多易用性问题,我们建议您不要使用这两种方案。
import copy
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from torch.quantization import DeQuantStub, QuantStub
from horizon_plugin_pytorch import March, set_march
from horizon_plugin_pytorch.fx.jit_scheme import Tracer
from horizon_plugin_pytorch.quantization import (
FakeQuantState,
get_qconfig,
PrepareMethod,
prepare,
set_fake_quantize,
)
class Net(torch.nn.Module):
def __init__(self, input_size, class_num) -> None:
super().__init__()
self.quant0 = QuantStub()
self.quant1 = QuantStub()
self.dequant = DeQuantStub()
self.conv = nn.Conv2d(3, 3, 1)
self.bn = nn.BatchNorm2d(3)
self.classifier = nn.Conv2d(3, class_num, input_size)
self.loss = nn.CrossEntropyLoss()
def forward(self, input, other, target=None):
# 不需要量化的前处理,使用 JIT_STRIP 时,将这些操作从计算图中剔除。
input = (input - 128) / 128.0
x = self.quant0(input)
y = self.quant1(other)
n = np.random.randint(1, 5)
m = np.random.randint(1, 5)
# 由于不重新生成 python code,此动态循环在 QAT 模型中保留。
for _ in range(n):
for _ in range(m):
# 动态循环中的代码块涉及到算子替换或算子融合时,必须进行标注。
# 标注的是需要算子替换或算子融合的逻辑,而不是 for 循环。
with Tracer.dynamic_block(self, "ConvBnAdd"):
x = self.conv(x)
x = self.bn(x)
x = x + y
x = self.classifier(x).squeeze()
# 由于不重新生成 python code,此动态控制流在 QAT 模型中保留
if self.training:
assert target is not None
x = self.dequant(x)
return F.cross_entropy(torch.softmax(x, dim=1), target)
else:
return torch.argmax(x, dim=1)
model = Net(6, 2)
train_example_input = (
torch.rand(2, 3, 6, 6) * 256,
torch.rand(2, 3, 6, 6),
torch.tensor([[0.0, 1.0], [1.0, 0.0]]),
)
eval_example_input = train_example_input[:2]
model.eval()
set_march(March.NASH_E)
model.qconfig = get_qconfig()
qat_model = prepare(
model,
example_inputs=copy.deepcopy(eval_example_input),
method=PrepareMethod.JIT_STRIP,
)
qat_model.graph.print_tabular()
# opcode name target args kwargs
# ------------- ---------------- --------------------------------------------------------- -------------------------------- ----------
# placeholder input_0 input_0 () {}
# call_module quant0 quant0 (input_0,) {}
# placeholder input_1 input_1 () {}
# call_module quant1 quant1 (input_1,) {}
# call_module conv conv (quant0,) {}
# call_module bn bn (conv,) {}
# get_attr _generated_add_0 _generated_add_0 () {}
# call_method add_2 add (_generated_add_0, bn, quant1) {}
# scope_end 是在 trace 过程中自动插入的,用于标记子 module 或动态代码块的边界,不对应实际计算
# call_function scope_end <function Tracer.scope_end at 0x7f65d90e5e50> ('_dynamic_block_ConvBnAdd',) {}
# call_module conv_1 conv (add_2,) {}
# call_module bn_1 bn (conv_1,) {}
# get_attr _generated_add_1 _generated_add_0 () {}
# call_method add_3 add (_generated_add_1, bn_1, quant1) {}
# call_function scope_end_1 <function Tracer.scope_end at 0x7f65d90e5e50> ('_dynamic_block_ConvBnAdd',) {}
# call_module classifier classifier (add_3,) {}
# call_function squeeze <method 'squeeze' of 'torch._C._TensorBase' objects> (classifier,) {}
# call_function argmax <built-in method argmax of type object at 0x7f66f04cf820> (squeeze,) {'dim': 1}
# call_function scope_end_2 <function Tracer.scope_end at 0x7f65d90e5e50> ('',) {}
# output output output ((argmax,),) {}
print(qat_model)
# GraphModuleImpl(
# (quant0): QuantStub(
# (activation_post_process): FakeQuantize(
# fake_quant_enabled=tensor([1], dtype=torch.uint8), observer_enabled=tensor([1], dtype=torch.uint8), quant_min=-128, quant_max=127, dtype=qint8, qscheme=torch.per_tensor_symmetric, ch_axis=-1, scale=tensor([1.]), zero_point=tensor([0])
# (activation_post_process): MinMaxObserver(min_val=tensor([]), max_val=tensor([]))
# )
# )
# (quant1): QuantStub(
# (activation_post_process): FakeQuantize(
# fake_quant_enabled=tensor([1], dtype=torch.uint8), observer_enabled=tensor([1], dtype=torch.uint8), quant_min=-128, quant_max=127, dtype=qint8, qscheme=torch.per_tensor_symmetric, ch_axis=-1, scale=tensor([1.]), zero_point=tensor([0])
# (activation_post_process): MinMaxObserver(min_val=tensor([]), max_val=tensor([]))
# )
# )
# (dequant): DeQuantStub()
# (conv): Identity() # 由于 forward 代码不变,conv 和 bn 仍将被执行,所以融合后必须将 Conv 和 Bn 替换为 Identity
# (bn): Identity()
# (classifier): Conv2d(
# 3, 2, kernel_size=(6, 6), stride=(1, 1)
# (activation_post_process): FakeQuantize(
# fake_quant_enabled=tensor([1], dtype=torch.uint8), observer_enabled=tensor([1], dtype=torch.uint8), quant_min=-128, quant_max=127, dtype=qint8, qscheme=torch.per_tensor_symmetric, ch_axis=-1, scale=tensor([1.]), zero_point=tensor([0])
# (activation_post_process): MinMaxObserver(min_val=tensor([]), max_val=tensor([]))
# )
# (weight_fake_quant): FakeQuantize(
# fake_quant_enabled=tensor([1], dtype=torch.uint8), observer_enabled=tensor([1], dtype=torch.uint8), quant_min=-128, quant_max=127, dtype=qint8, qscheme=torch.per_channel_symmetric, ch_axis=0, scale=tensor([1., 1.]), zero_point=tensor([0, 0])
# (activation_post_process): MinMaxObserver(min_val=tensor([]), max_val=tensor([]))
# )
# )
# (loss): CrossEntropyLoss()
# (_generated_add_0): ConvAdd2d( # 自动将 '+' 替换为 Module 形式,并将 Conv 和 Bn 融合进来
# 3, 3, kernel_size=(1, 1), stride=(1, 1)
# (activation_post_process): FakeQuantize(
# fake_quant_enabled=tensor([1], dtype=torch.uint8), observer_enabled=tensor([1], dtype=torch.uint8), quant_min=-128, quant_max=127, dtype=qint8, qscheme=torch.per_tensor_symmetric, ch_axis=-1, scale=tensor([1.]), zero_point=tensor([0])
# (activation_post_process): MinMaxObserver(min_val=tensor([]), max_val=tensor([]))
# )
# (weight_fake_quant): FakeQuantize(
# fake_quant_enabled=tensor([1], dtype=torch.uint8), observer_enabled=tensor([1], dtype=torch.uint8), quant_min=-128, quant_max=127, dtype=qint8, qscheme=torch.per_channel_symmetric, ch_axis=0, scale=tensor([1., 1., 1.]), zero_point=tensor([0, 0, 0])
# (activation_post_process): MinMaxObserver(min_val=tensor([]), max_val=tensor([]))
# )
# )
# )
qat_model.train()
set_fake_quantize(qat_model, FakeQuantState.QAT)
for _ in range(3):
ret = qat_model(*train_example_input)
ret.backward()
在提供 example_inputs 的情况下,prepare 默认会对模型结构进行检查。如果检查完成,可以在运行目录下找到 model_check_result.txt 文件,如果检查失败,则需要根据警告提示修改模型或单独调用 horizon_plugin_pytorch.utils.check_model.check_qat_model 检查模型。检查流程和 debug 工具中的 check_qat_model 一致,结果文件的分析详见 check_qat_model 相关文档。