算子融合

训练工具支持的算子融合可分为两大类:1. 吸收 BN;2. 融合 Add、ReLU(6)。

吸收 BN

吸收 BN 的目的是为了减少模型的计算量。因为 BN 是线性变换过程,因此,当 BNConv 一起出现的时候,可以把 BN 的参数吸收到 Conv 的参数中,从而在部署的模型中消除 BN 的计算。

吸收的计算过程如下:

fuse_bn

通过吸收 BN ,可以把 Conv2d + BN2d 简化为 Conv2d

absorb_bn

融合 Add、ReLU(6)

和 CUDA Kernel Fusion 中将 CUDA Kernel 融合以提高计算速度不同,训练工具支持的融合更加偏重量化层面。

BPU 硬件针对常见的模型基本结构做了优化,在计算 Conv -> Add -> ReLU 这种算子组合时,可使算子间的数据传递保留高精度的状态,提高模型整体的数值精度。因此在对模型进行量化时,我们可以将 Conv -> Add -> ReLU 视为一个整体。

由于训练工具对模型进行量化改造时以 torch.nn.Module 为单位,为了在量化时将 Conv -> Add -> ReLU 视为一个整体,需要将它们合并为一个 Module

算子融合除了可以使中间结果保留高精度状态之外,也可以省去将中间结果转化为低精度表示的过程,因此执行速度和不融合相比也会更快。

由于算子融合既可以提高模型精度,又可以提高模型速度,一般应该对所有可融合的部分进行融合。

实现原理

得益于 FX 可以获取计算图的优势,训练工具可以自动化地对模型的计算图进行分析,根据预定义的 fusion pattern 对可融合部分进行匹配,并通过 submodule 替换实现融合的操作。下面举例进行说明:

吸收 BN 和融合 Add、ReLU(6) 可以通过相同的机制完成,因此在融合时不需要进行区分。

import torch from torch import nn from torch.quantization import DeQuantStub from horizon_plugin_pytorch.quantization import QuantStub from horizon_plugin_pytorch.quantization import fuse_fx class ModelForFusion(torch.nn.Module): def __init__( self, ): super(ModelForFusion, self).__init__() self.quantx = QuantStub() self.quanty = QuantStub() self.conv = nn.Conv2d(3, 3, 3) self.bn = nn.BatchNorm2d(3) self.relu = nn.ReLU() self.dequant = DeQuantStub() def forward(self, x, y): x = self.quantx(x) y = self.quanty(y) x = self.conv(x) x = self.bn(x) x = x + y x = self.relu(x) x = self.dequant(x) return x float_model = ModelForFusion() fused_model = fuse_fx(float_model) print(fused_model) """ ModelForFusion( (quantx): QuantStub() (quanty): QuantStub() (conv): Identity() (bn): Identity() (relu): Identity() (dequant): DeQuantStub() (_generated_add_0): ConvAddReLU2d( (conv): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1)) (relu): ReLU() ) ) def forward(self, x, y): quantx = self.quantx(x); x = None quanty = self.quanty(y); y = None _generated_add_0 = self._generated_add_0 add_1 = self._generated_add_0(quantx, quanty); quantx = quanty = None dequant = self.dequant(add_1); add_1 = None return dequant """

可以看到,对模型执行算子融合操作后,BN 被吸收进 Conv 中,且 Conv、Add、ReLU 被融合进一个 Module 中(_generated_add_0)。原本的 submodule 被替换为 Identity,且不在 forward 代码中调用。

FX 自动地将模型中 x = x + y 的加号替换为了名为 _generated_add_0Module 形式,以支持算子融合和量化的相关操作。

可以融合的算子

目前支持的可融合的算子组合见以下函数定义:

import operator import torch from torch import nn from horizon_plugin_pytorch import nn as horizon_nn def register_fusion_patterns(): convs = ( nn.Conv2d, nn.ConvTranspose2d, nn.Conv3d, nn.Linear, ) bns = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm) adds = ( nn.quantized.FloatFunctional.add, horizon_nn.quantized.FloatFunctional.add, torch.add, operator.add, # 即代码中使用的加号 ) relus = (nn.ReLU, nn.ReLU6, nn.functional.relu, nn.functional.relu6) for conv in convs: for bn in bns: for add in adds: for relu in relus: # conv bn register_fusion_pattern((bn, conv))(ConvBNAddReLUFusion) # conv relu register_fusion_pattern((relu, conv))(ConvBNAddReLUFusion) # conv add register_fusion_pattern((add, conv, MatchAllNode))( ConvBNAddReLUFusion ) # conv 的输出作为 add 的第一个输入 register_fusion_pattern((add, MatchAllNode, conv))( ConvBNAddedReLUFusion ) # conv 的输出作为 add 的第二个输入 # conv bn relu register_fusion_pattern((relu, (bn, conv)))( ConvBNAddReLUFusion ) # conv bn add register_fusion_pattern((add, (bn, conv), MatchAllNode))( ConvBNAddReLUFusion ) register_fusion_pattern((add, MatchAllNode, (bn, conv)))( ConvBNAddedReLUFusion ) # conv add relu register_fusion_pattern((relu, (add, conv, MatchAllNode)))( ConvBNAddReLUFusion ) register_fusion_pattern((relu, (add, MatchAllNode, conv)))( ConvBNAddedReLUFusion ) # conv bn add relu register_fusion_pattern( (relu, (add, (bn, conv), MatchAllNode)) )(ConvBNAddReLUFusion) register_fusion_pattern( (relu, (add, MatchAllNode, (bn, conv))) )(ConvBNAddedReLUFusion)