hmct.api.ORTExecutor

接口说明

HMCT提供的模型推理功能,用于支持对模型转换过程中产生的中间模型进行推理。

接口形式

class ORTExecutor(ORTExecutorBase): def __init__(self, model:ModelProto): def create_session(self): -> InferenceSession def get_support_devices(cls) -> list[str] def to(self, device:Union[str, list[str]]) -> None def get_inputs(self) -> list[str] def get_outputs(self) -> list[str] def inference(self, inputs:Dict[str, np.ndarray])->Dict[str, np.ndarray]

成员函数

成员函数详细说明返回值合法参数值取值范围
ORTExecutor.init(self, model: ModelProto)类初始化函数,传入一个onnx ModelProto对象一个onnx ModelProto对象
def create_session(self): -> InferenceSession创建用于推理的session用于推理的InferenceSession无参数
def get_support_devices(cls) -> list[str]获取当前ORTExecutor所有支持的device当前支持的device的string list无参数
def to(self, device:Union[str, list[str]]) -> None修改模型推理是运行device'cuda', 'cpu'或者二者组成的list
def get_inputs(self) -> list[str]

获取一个由模型输入的NodeArg类组成的list,NodeArg类有三个成员变量,name表示输入的名字,type表示输入的数据类型,shape表示输入的大小

class NodeArg: def __init__(self, name: str, type: int, shape: Sequence[Union[int, str]]): self.name = name self.type = type self.shape = shape
由模型输入的NodeArg类组成的list无参数
def get_outputs(self) -> list[str]获取一个由模型输出的NodeArg类组成的list,NodeArg类有三个成员变量,name表示输出的名字,type表示输出的数据类型,shape表示输出的大小,NodeArg定义同get_inputs中的描述由模型输出的NodeArg类组成的list无参数
def inference(self, inputs:Dict[str, np.ndarray])->Dict[str, np.ndarray]使用输入数据进行一次前向推理,获得推理结果并返回一个Dict,key是模型输出name str,value是输出结果np.ndarray

一个dict,key为输入name的字符串,value为这次推理的输入

{ 'input_name0': np.ndarray, 'input_name1': np.ndarray, ... }