HMCT提供的模型推理功能,用于支持对模型转换过程中产生的中间模型进行推理。
成员函数 | 详细说明 | 返回值 | 合法参数值取值范围 |
ORTExecutor.init(self, model: ModelProto) | 类初始化函数,传入一个onnx ModelProto对象 | 无 | 一个onnx ModelProto对象 |
def create_session(self): -> InferenceSession | 创建用于推理的session | 用于推理的InferenceSession | 无参数 |
def get_support_devices(cls) -> list[str] | 获取当前ORTExecutor所有支持的device | 当前支持的device的string list | 无参数 |
def to(self, device:Union[str, list[str]]) -> None | 修改模型推理是运行device | 无 | 'cuda', 'cpu'或者二者组成的list |
def get_inputs(self) -> list[str] | 获取一个由模型输入的NodeArg类组成的list,NodeArg类有三个成员变量,name表示输入的名字,type表示输入的数据类型,shape表示输入的大小 | 由模型输入的NodeArg类组成的list | 无参数 |
def get_outputs(self) -> list[str] | 获取一个由模型输出的NodeArg类组成的list,NodeArg类有三个成员变量,name表示输出的名字,type表示输出的数据类型,shape表示输出的大小,NodeArg定义同get_inputs中的描述 | 由模型输出的NodeArg类组成的list | 无参数 |
def inference(self, inputs:Dict[str, np.ndarray])->Dict[str, np.ndarray] | 使用输入数据进行一次前向推理,获得推理结果并返回 | 一个Dict,key是模型输出name str,value是输出结果np.ndarray | 一个dict,key为输入name的字符串,value为这次推理的输入 |