FX Quantization 原理介绍

阅读此文档前,建议先阅读 torch.fx — PyTorch documentation,以对 torch 的 FX 机制有初步的了解。

FX 采用符号执行的方式,可以在 nn.Module 或 function 的层面对模型建图,从而实现自动化的 fuse 以及其他基于图的优化。

量化流程

Fuse(可选)

FX 可以感知计算图,所以可以实现自动化的算子融合,您不再需要手动指定需要融合的算子,直接调用接口即可。

fused_model = horizon.quantization.fuse_fx(model)
  • 注意 fuse_fx 没有 inplace 参数,因为内部需要对模型做 symbolic trace 生成一个 GraphModule,所以无法做到 inplace 的修改。
  • fused_modelmodel 会共享几乎所有属性(包括子模块、算子等),因此在 fuse 之后请不要对 model 做任何修改,否则可能影响到 fused_model
  • 不必显式调用 fuse_fx 接口,因为后续的 prepare 接口内部集成了 fuse 的过程。

Prepare

在调用 prepare 接口之前必须根据目标硬件平台设置全局的 march。接口内部会先执行 fuse 过程(即使模型已经 fuse 过了),再将模型中符合条件的算子替换为 horizon.nn.qat 中的实现。

  • 可以根据需要选择合适的 qconfig(Calibtaion 或 QAT,注意两种 qconfig 不能混用)。
  • fuse_fx 类似,此接口不支持 inplace 参数,且在 prepare 之后请不要对输入的模型做任何修改。
horizon.march.set_march(horizon.march.March.NASH) qat_model = horizon.quantization.prepare( model, qconfig_setter = horizon.quantization.qconfig_template.default_qat_qconfig_setter, method = horizon.quantization.PrepareMethod.SYMBOLIC, )

Eager Mode 兼容性

大部分情况下,FX 量化的接口可以直接替换 eager mode 量化的接口,但是不能和 eager mode 的接口混用。部分模型在以下情况下需要对代码结构做一定的修改。

  • FX 不支持的操作:torch 的 symbolic trace 支持的操作是有限的,例如不支持将非静态变量作为判断条件、默认不支持 torch 以外的 pkg(如 numpy)等,且未执行到的条件分支将被丢弃。
  • 不想被 FX 处理的操作:如果模型的前后处理中使用了 torch 的 op,FX 在 trace 时会将他们视为模型的一部分,产生不符合预期的行为(例如将 torch 的某些 function 调用替换为 FloatFunctional)。

以上两种情况,都可以采用 wrap 的方法来避免,下面以 RetinaNet 为例进行说明。

from horizon_plugin_pytorch.fx.fx_helper import wrap as fx_wrap class RetinaNet(nn.Module): def __init__( self, backbone: nn.Module, neck: Optional[nn.Module] = None, head: Optional[nn.Module] = None, anchors: Optional[nn.Module] = None, targets: Optional[nn.Module] = None, post_process: Optional[nn.Module] = None, loss_cls: Optional[nn.Module] = None, loss_reg: Optional[nn.Module] = None, ): super(RetinaNet, self).__init__() self.backbone = backbone self.neck = neck self.head = head self.anchors = anchors self.targets = targets self.post_process = post_process self.loss_cls = loss_cls self.loss_reg = loss_reg def rearrange_head_out(self, inputs: List[torch.Tensor], num: int): outputs = [] for t in inputs: outputs.append(t.permute(0, 2, 3, 1).reshape(t.shape[0], -1, num)) return torch.cat(outputs, dim=1) def forward(self, data: Dict): feat = self.backbone(data["img"]) feat = self.neck(feat) if self.neck else feat cls_scores, bbox_preds = self.head(feat) if self.post_process is None: return cls_scores, bbox_preds # 将不需要建图的操作封装为一个 method 即可,FX 将不再关注 method 内部的逻辑, # 仅将它原样保留(method 中调用的 module 仍可被设置 qconfig,被 prepare 替换) return self._post_process( data, feat, cls_scores, bbox_preds) @fx_wrap() # fx_wrap 支持直接装饰 class method def _post_process(self, data, feat, cls_scores, bbox_preds) anchors = self.anchors(feat) # 对 self.training 的判断必须封装起来,否则在 symbolic trace 之后,此判断逻辑会被丢掉 if self.training: cls_scores = self.rearrange_head_out( cls_scores, self.head.num_classes ) bbox_preds = self.rearrange_head_out(bbox_preds, 4) gt_labels = [ torch.cat( [data["gt_bboxes"][i], data["gt_classes"][i][:, None] + 1], dim=-1, ) for i in range(len(data["gt_classes"])) ] gt_labels = [gt_label.float() for gt_label in gt_labels] _, labels = self.targets(anchors, gt_labels) avg_factor = labels["reg_label_mask"].sum() if avg_factor == 0: avg_factor += 1 cls_loss = self.loss_cls( pred=cls_scores.sigmoid(), target=labels["cls_label"], weight=labels["cls_label_mask"], avg_factor=avg_factor, ) reg_loss = self.loss_reg( pred=bbox_preds, target=labels["reg_label"], weight=labels["reg_label_mask"], avg_factor=avg_factor, ) return { "cls_loss": cls_loss, "reg_loss": reg_loss, } else: preds = self.post_process( anchors, cls_scores, bbox_preds, [torch.tensor(shape) for shape in data["resized_shape"]], ) assert ( "pred_bboxes" not in data.keys() ), "pred_bboxes has been in data.keys()" data["pred_bboxes"] = preds return data