from typing import Dict, List, Optional

import torch
from torch import nn

from horizon_plugin_pytorch.fx.fx_helper import wrap as fx_wrap
class RetinaNet(nn.Module):
    """RetinaNet detector whose post-processing is hidden from FX tracing.

    ``forward`` runs backbone -> optional neck -> head. Everything after the
    head (anchor generation plus either loss computation or box decoding) is
    placed in ``_post_process``, which is decorated with ``fx_wrap`` so FX
    symbolic tracing records it as a single call instead of tracing into it.

    Args:
        backbone: Feature extractor applied to ``data["img"]``.
        neck: Optional module applied to the backbone features.
        head: Module returning ``(cls_scores, bbox_preds)``. In training mode
            it must expose a ``num_classes`` attribute.
        anchors: Module producing anchors from the (neck) features.
        targets: Called as ``targets(anchors, gt_labels)``; must return a
            2-tuple whose second item is a dict with ``cls_label``,
            ``cls_label_mask``, ``reg_label`` and ``reg_label_mask``.
        post_process: Decoding module used in eval mode. When ``None``,
            ``forward`` returns the raw head outputs and skips everything else.
        loss_cls: Classification loss, used in training mode only.
        loss_reg: Box-regression loss, used in training mode only.
    """

    def __init__(
        self,
        backbone: nn.Module,
        neck: Optional[nn.Module] = None,
        head: Optional[nn.Module] = None,
        anchors: Optional[nn.Module] = None,
        targets: Optional[nn.Module] = None,
        post_process: Optional[nn.Module] = None,
        loss_cls: Optional[nn.Module] = None,
        loss_reg: Optional[nn.Module] = None,
    ):
        super().__init__()
        self.backbone = backbone
        self.neck = neck
        self.head = head
        self.anchors = anchors
        self.targets = targets
        self.post_process = post_process
        self.loss_cls = loss_cls
        self.loss_reg = loss_reg

    def rearrange_head_out(
        self, inputs: List[torch.Tensor], num: int
    ) -> torch.Tensor:
        """Flatten per-level head outputs and concatenate them along dim 1.

        Each tensor has dim 1 moved to the end (``permute(0, 2, 3, 1)``) and
        is reshaped to ``(batch, -1, num)``. The per-level results are then
        concatenated along dim 1.

        Args:
            inputs: 4-D tensors, one per feature level (presumably
                ``(N, C, H, W)`` with ``C`` a multiple of ``num`` -- confirm).
            num: Size of the last dimension after reshaping (``num_classes``
                for scores, 4 for boxes).

        Returns:
            Tensor of shape ``(batch, total_positions, num)``.
        """
        return torch.cat(
            [t.permute(0, 2, 3, 1).reshape(t.shape[0], -1, num) for t in inputs],
            dim=1,
        )

    def forward(self, data: Dict):
        """Run the detector on ``data``.

        Args:
            data: Input dict. ``data["img"]`` is fed to the backbone. In
                training mode it must also contain ``gt_bboxes`` and
                ``gt_classes``. In eval mode it must contain
                ``resized_shape``.

        Returns:
            ``(cls_scores, bbox_preds)`` from the head when ``post_process``
            is ``None``. Otherwise, the result of ``_post_process``: a loss
            dict in training mode, or ``data`` with ``pred_bboxes`` added in
            eval mode.
        """
        feat = self.backbone(data["img"])
        if self.neck is not None:
            feat = self.neck(feat)
        cls_scores, bbox_preds = self.head(feat)

        if self.post_process is None:
            return cls_scores, bbox_preds

        # Put operations that must not be traced into a method. FX then stops
        # looking inside the method and keeps the call as-is. Modules called
        # inside the method can still be given a qconfig and be replaced by
        # prepare.
        return self._post_process(data, feat, cls_scores, bbox_preds)

    @fx_wrap()  # fx_wrap can decorate a class method directly.
    def _post_process(self, data, feat, cls_scores, bbox_preds):
        """Compute losses (training mode) or decode predictions (eval mode).

        Returns:
            In training mode, ``{"cls_loss": ..., "reg_loss": ...}``. In eval
            mode, ``data`` with ``data["pred_bboxes"]`` set to the output of
            ``self.post_process``.

        Raises:
            AssertionError: In eval mode, if ``data`` already contains
                ``"pred_bboxes"``.
        """
        anchors = self.anchors(feat)

        # The self.training check must live inside this wrapped method.
        # Otherwise symbolic tracing would fix the branch at trace time and
        # drop the other one.
        if self.training:
            cls_scores = self.rearrange_head_out(
                cls_scores, self.head.num_classes
            )
            bbox_preds = self.rearrange_head_out(bbox_preds, 4)

            # Per image: append the class id (shifted by +1; presumably
            # index 0 is reserved, e.g. for background -- confirm against
            # `targets`) as an extra last column to the boxes, as float.
            gt_bboxes = data["gt_bboxes"]
            gt_classes = data["gt_classes"]
            gt_labels = [
                torch.cat(
                    [gt_bboxes[i], gt_classes[i][:, None] + 1], dim=-1
                ).float()
                for i in range(len(gt_classes))
            ]

            _, labels = self.targets(anchors, gt_labels)

            # Shared normalizer for both losses; bumped to 1 when the
            # regression mask is all zero, so it is never zero.
            avg_factor = labels["reg_label_mask"].sum()
            if avg_factor == 0:
                avg_factor += 1

            cls_loss = self.loss_cls(
                pred=cls_scores.sigmoid(),
                target=labels["cls_label"],
                weight=labels["cls_label_mask"],
                avg_factor=avg_factor,
            )
            reg_loss = self.loss_reg(
                pred=bbox_preds,
                target=labels["reg_label"],
                weight=labels["reg_label_mask"],
                avg_factor=avg_factor,
            )
            return {
                "cls_loss": cls_loss,
                "reg_loss": reg_loss,
            }

        preds = self.post_process(
            anchors,
            cls_scores,
            bbox_preds,
            [torch.tensor(shape) for shape in data["resized_shape"]],
        )
        assert (
            "pred_bboxes" not in data
        ), "pred_bboxes has been in data.keys()"
        data["pred_bboxes"] = preds
        return data