量化感知训练指南
量化感知训练通过在模型中插入一些伪量化节点,从而使得通过量化感知训练得到的模型转换成定点模型时尽可能减少精度损失。
量化感知训练和传统的模型训练无异,可以从零开始,搭建一个伪量化模型,然后对该伪量化模型进行训练。
由于部署的硬件平台有诸多限制,搞清这些限制,并且根据这些限制搭建伪量化模型门槛较高。
量化感知训练工具通过在您提供的浮点模型上根据部署平台的限制自动插入伪量化量化算子的方法,降低开发量化模型的门槛。
量化感知训练由于施加了各种限制,因此,一般来说,量化感知训练比纯浮点模型的训练更加困难。量化感知训练工具的目标是降低量化感知训练的难度,降低量化模型部署的工程难度。
流程和示例
虽然量化感知训练工具不强制要求从一个预训练的浮点模型开始,但是,经验表明,通常从预训练的高精度浮点模型开始量化感知训练能大大降低量化感知训练的难度。
# 将模型转为 QAT 状态
qat_model = prepare(
float_model,
example_input,
qconfig_setter = horizon.quantization.qconfig_template.default_qat_qconfig_setter,
).to(device)
# 加载 Calibration 模型中的量化参数
qat_model.load_state_dict(calib_model.state_dict())
# 进行量化感知训练
# 作为一个 filetune 过程,量化感知训练一般需要设定较小的学习率
optimizer = torch.optim.SGD(
qat_model.parameters(), lr=0.0001, weight_decay=2e-4
)
for nepoch in range(epoch_num):
# 注意此处对 QAT 模型 training 状态的控制方法
qat_model.train()
set_fake_quantize(qat_model, FakeQuantState.QAT)
train_one_epoch(
qat_model,
nn.CrossEntropyLoss(),
optimizer,
None,
train_data_loader,
device,
)
# 注意此处对 QAT 模型 eval 状态的控制方法
qat_model.eval()
set_fake_quantize(qat_model, FakeQuantState.VALIDATION)
# 测试 qat 模型精度
top1, top5 = evaluate(
qat_model,
eval_data_loader,
device,
)
print(
"QAT model: evaluation Acc@1 {:.3f} Acc@5 {:.3f}".format(
top1.avg, top5.avg
)
)
# 测试 quantized 模型精度
qat_hbir_model = horizon_plugin_pytorch.quantization.hbdk4.export(
qat_model. example_input
)
quantized_hbir_model = hbdk4.compiler.convert(qat_hbir_model)
top1, top5 = evaluate(
quantized_hbir_model,
eval_data_loader,
)
print(
"Quantized model: evaluation Acc@1 {:.3f} Acc@5 {:.3f}".format(
top1.avg, top5.avg
)
)
注意
由于部署平台的底层限制,QAT 模型无法完全代表最终上板精度,请务必监控 quantized 模型精度,确保 quantized 模型精度正常,否则可能出现模型上板掉点问题。
由上述示例代码可以看到,与传统的纯浮点模型训练相比,量化感知训练多了两个步骤:
- prepare。
- 加载 Calibration 模型参数。
prepare
这一步骤的目标是对浮点网络进行变换,插入伪量化节点。
加载 Calibration 模型参数
通过加载 Calibration 得到的伪量化参数,来获得一个较好的初始化。
注意
算子的兼容性依赖 torch.nn.Module 的 _version 变量实现,_version 会在 state_dict._metadata 中保存,请确保在保存或加载 state_dict 的过程中保留了 _metadata,否则可能引起兼容性问题。
对于 state_dict 的修改不应当只关注 key 和 value,同样需要考虑 _metadata。如下是一个复制 state_dict 的例子:
new_state_dict = OrderedDict()
for k, v in state_dict.items():
new_state_dict[k] = v
if hasattr(state_dict, "_metadata"):
new_state_dict._metadata = copy.deepcopy(state_dict._metadata)
训练迭代
至此,完成了伪量化模型的搭建和参数的初始化,然后就可以进行常规的训练迭代和模型参数更新,并且监控 quantized 模型精度。
伪量化算子
量化感知训练和传统的浮点模型的训练主要区别在于插入了伪量化算子,并且,不同量化感知训练算法也是通过伪量化算子来体现的,因此,这里介绍一下伪量化算子。
注解
由于 BPU 只支持对称量化,因此,这里以对称量化为例介绍。
伪量化过程
以 int8 量化感知训练为例,一般来说,伪量化算子的计算过程如下:
fake_quant_x = clip(round(x / scale),-128, 127) * scale
和 Conv2d 通过训练来优化 weight, bias 参数类似,伪量化算子要通过训练来优化 scale 参数。
然而,由于 round 作为阶梯函数,其梯度为 0,从而导致了伪量化算子无法直接通过梯度反向传播的方式进行训练。
解决这一问题,通常有两种方案:基于统计的方法和基于学习的方法。
基于统计的方法
量化的目标是把 Tensor 中的浮点数通过 scale 参数均匀地映射到 int8 表示的 [-128, 127] 的范围上。既然是均匀映射,那么很容易得到 scale 的计算方法:
def compute_scale(x: Tensor):
xmin, xmax = x.max(), maxv = x.min()
return max(xmin.abs(), xmax.abs()) / 256.0
由于 Tensor 中数据分布不均匀以及外点问题,又衍生了不同的计算 xmin 和 xmax 的方法。可以参考 MinMaxObserver
等。
在工具中的使用方法请参考 QConfig 详解。
基于学习的方法
虽然 round 的梯度为 0,研究者通过实验发现,在该场景下,如果直接设置其梯度为 1 也可以使得模型收敛到预期的精度。
def round_ste(x: Tensor):
return (x.round() - x).detach() + x
在工具中的使用方法请参考 FakeQuantize 的定义。
如您有兴趣进一步了解,可以参考如下论文:Learned Step Size Quantization。