HX study Computer Vision

ONNX-TensorRT

2021-01-23

ONNX-TensorRT

第一部分:ONNX-TensorRT工程

Onnx-tensorrt工程是用来将onnx模型转成tensorrt可用trtmodel的工程,其中包含了解析onnx op的代码,也可以根据需要添加自定义op。

当然如果没有自定义层之类的修改也可以直接使用tensorrt中nvonnxparser.lib解析。

nvonnxparser库概览

nvonnxparser库的核心代码文件见CMakeLists.txt文件,如下:

set(IMPORTER_SOURCES
  NvOnnxParser.cpp
  ModelImporter.cpp
  builtin_op_importers.cpp
  onnx2trt_utils.cpp
  ShapedWeights.cpp
  ShapeTensor.cpp
  LoopHelpers.cpp
  RNNHelpers.cpp
  OnnxAttrs.cpp
)

最终,这些代码被编译成动态链接库nvonnxparser.so和静态链接库nvonnxparser_static.a

add_library(nvonnxparser SHARED ${IMPORTER_SOURCES})
target_include_directories(nvonnxparser PUBLIC ${ONNX_INCLUDE_DIRS} ${TENSORRT_INCLUDE_DIR})
target_link_libraries(nvonnxparser PUBLIC onnx_proto ${PROTOBUF_LIBRARY} ${TENSORRT_LIBRARY})
add_library(nvonnxparser_static STATIC ${IMPORTER_SOURCES})
target_include_directories(nvonnxparser_static PUBLIC ${ONNX_INCLUDE_DIRS} ${TENSORRT_INCLUDE_DIR})
target_link_libraries(nvonnxparser_static PUBLIC onnx_proto ${PROTOBUF_LIBRARY} ${TENSORRT_LIBRARY})

解析流程解读

解析onnx文件流程,包含createParserparseFromFile两部分,对应以下两行代码,不熟悉tensorrt解析的可以先简单了解一下再回来看

nvonnxparser::createParser(*network, gLogger)

onnxParser->parseFromFile(source.onnxmodel().c_str(), 1)

createParser是最外层接口,定义在NvOnnxParser.h中,返回IParser

/** \brief 创建一个解析器对象
 *
 * \param network 解析器将写入的network
 * \param logger The logger to use
 * \return a new parser object or NULL if an error occurred
 * \see IParser
 */
#ifdef _MSC_VER
TENSORRTAPI IParser* createParser(nvinfer1::INetworkDefinition& network,
                                  nvinfer1::ILogger& logger)
#else
inline IParser* createParser(nvinfer1::INetworkDefinition& network,
                             nvinfer1::ILogger& logger)
#endif
{
    return static_cast<IParser*>(
        createNvOnnxParser_INTERNAL(&network, &logger, NV_ONNX_PARSER_VERSION));
}
/** \class IParser
 *
 * \brief 用于将ONNX模型解析为TensorRT网络定义的对象
 */
class IParser
{
public:
    /** 将序列化的ONNX模型解析到TensorRT网络中。这种方法的诊断价值非常有限。如果由于任何原因(例如不支持的IR版本、不支持的opset等)解析序列化模型失败,则用户有责任拦截并报告错误。到要获得更好的诊断,请使用下面的parseFromFile方法。
     */
    virtual bool parse(void const* serialized_onnx_model,
                       size_t serialized_onnx_model_size,
                       const char* model_path = nullptr)
        = 0;
    
    /** \brief 解析一个onnx模型文件,可以是一个二进制protobuf或者一个文本onnx模型调用里面的Parse方法
     */
    virtual bool parseFromFile(const char* onnxModelFile, int verbosity) = 0;

    /** \brief 检查TensorRT是否支持特定的ONNX模型
     */
    virtual bool supportsModel(void const* serialized_onnx_model,
                               size_t serialized_onnx_model_size,
                               SubGraphCollection_t& sub_graph_collection,
                               const char* model_path = nullptr)
        = 0;

    /** \brief 考虑到用户提供的权重,将序列化的ONNX模型解析到TensorRT网络中
     */
    virtual bool parseWithWeightDescriptors(
        void const* serialized_onnx_model, size_t serialized_onnx_model_size,
        uint32_t weight_count,
        onnxTensorDescriptorV1 const* weight_descriptors)
        = 0;

    /** \brief 返回解析器是否支持指定的运算符
     */
    virtual bool supportsOperator(const char* op_name) const = 0;
//...
    
//...
};

nvonnxparser::createParser函数通过return new onnx2trt::ModelImporter(network, logger),返回类ModelImporter,类ModelImporter继承IParser并重写了虚函数,。

class ModelImporter : public nvonnxparser::IParser
{
protected:
    string_map<NodeImporter> _op_importers;
    virtual Status importModel(::ONNX_NAMESPACE::ModelProto const& model, uint32_t weight_count,
        onnxTensorDescriptorV1 const* weight_descriptors);

private:
    ImporterContext _importer_ctx;
    RefitMap_t mRefitMap;
    std::list<::ONNX_NAMESPACE::ModelProto> _onnx_models; // Needed for ownership of weights
    int _current_node;
    std::vector<Status> _errors;
public:
    ModelImporter(nvinfer1::INetworkDefinition* network, nvinfer1::ILogger* logger)
        : _op_importers(getBuiltinOpImporterMap())
        , _importer_ctx(network, logger, &mRefitMap)
    {
    }
//...
    
//...
}

通过_op_importers(getBuiltinOpImporterMap())调用builtin_op_importers.h中的getBuiltinOpImporterMap()得到所有onnx注册的op,builtin_op_importers中所有的op,都将以DEFINE_BUILTIN_OP_IMPORTER形式出现,只要按照名字和版本注册了,那么当你加载onnx的时候,都会被认识

builtin_op_importers

  • onnxmodel到trtmodel的parse代码。从onnxmodel的input出发,最后,输出trtmodel的输出tensor_ptr;
  • onnx支持的builtin operators包括Conv, Argmax, Unsample,Relu等,具体可以参考operators.md文件;
  • 文件中根据onnx层的类型名调用相应的DEFINE_BUILTIN_OP_IMPORTER(Conv), DEFINE_BUILTIN_OP_IMPORTER(Argmax), DEFINE_BUILTIN_OP_IMPORTER(Unsample), DEFINE_BUILTIN_OP_IMPORTER(Relu)等,从而完成对应层的onnx2trtmodel的parser。

parseFromFile解析入口onnxParser->parseFromFile(source.onnxmodel().c_str(), 1),流程如下

调用ModelImporter::parseFromFile开始做解析

然后调用到ModelImporter::parse

然后是ModelImporter::parseWithWeightDescriptors

然后是ModelImporter::importModel

然后是ModelImporter::importInputs,这里ModelImporter::importInput是控制输入的,如果想对onnx的输入尺寸做修改,请修改里面的trt_dims即可

然后是ModelImporter::parseGraph,这里会调用getBuiltinOpImporterMap函数,获得builtin_op_importers所有自定义op

解析时查询op,调用(*importFunc),跳转到DEFINE_BUILTIN_OP_IMPORTER(op)

const string_map<NodeImporter>& opImporters = getBuiltinOpImporterMap();
//...
//...
// Dispatch to appropriate converter.
const NodeImporter* importFunc{nullptr};
if (opImporters.count(node.op_type()))
{
    importFunc = &opImporters.at(node.op_type());
}
else
{
    LOG_INFO("No importer registered for op: " << node.op_type() << ". Attempting to import as plugin.");
    importFunc = &opImporters.at("FallbackPluginImporter");
}
std::vector<TensorOrWeights> outputs;

GET_VALUE((*importFunc)(ctx, node, nodeInputs), &outputs);

这里importFunc类型是NodeImporter,定义的std::function,输入(ctx, node, nodeInputs)

typedef std::function<NodeImportResult(
    IImporterContext* ctx, ::ONNX_NAMESPACE::NodeProto const& node, std::vector<TensorOrWeights>& inputs)>
    NodeImporter;

DEFINE_BUILTIN_OP_IMPORTER(op)通过宏定义

#define DECLARE_BUILTIN_OP_IMPORTER(op)                                                                       \
    NodeImportResult import##op(                                                                              \
        IImporterContext* ctx, ::ONNX_NAMESPACE::NodeProto const& node, std::vector<TensorOrWeights>& inputs)

#define DEFINE_BUILTIN_OP_IMPORTER(op)                                                                        \
    NodeImportResult import##op(                                                                              \
        IImporterContext* ctx, ::ONNX_NAMESPACE::NodeProto const& node, std::vector<TensorOrWeights>& inputs);\
    static const bool op##_registered_builtin_op = registerBuiltinOpImporter(#op, import##op);                \
    IGNORE_UNUSED_GLOBAL(op##_registered_builtin_op);                                                         \
    NodeImportResult import##op(                                                                              \
        IImporterContext* ctx, ::ONNX_NAMESPACE::NodeProto const& node, std::vector<TensorOrWeights>& inputs)

主要完成以下三项工作:

1、将onnx输入数据转化为trt要求的数据格式

2、建立trt层,层定义参考Nvinfer.h

3、计算trt输出结果

第二部分:自定义op流程

DEFINE_BUILTIN_OP_IMPORTER

plugin层定义

官方plugin

NvInferRuntimeCommon.h

IPluginV2:用户实现层的插件类。插件是应用程序实现自定义层的机制。当与IPluginCreator结合使用时,它提供了一种在反序列化期间注册插件和查找插件注册表的机制。

IPluginV2Ext:此接口通过支持不同的输出数据类型和跨批处理的广播,为IPluginV2接口提供了额外的功能

IPluginV2IOExt:此接口通过扩展不同的I/O数据类型和张量格式,为IPluginV2Ext接口提供了额外的功能。

IPluginCreator:用户实现层的插件创建者类。

IPluginRegistry:所有插件的单一注册点,反序列化期间查找插件实现,pluginregistry只支持IPluginV2类型的插件,并且应该有一个相应的IPluginCreator实现。

来源参考

TensorRT

TensorRT学习(二)通过C++使用

TensorRT学习(三)通过自定义层扩展TensorRT

Onnx-tensorrt详解之nvonnxparser库

实现TensorRT自定义插件(plugin)自由!

TensorRT7实现插件流程

TensorRT的自定义算子Plugin的实现

TensorRT(4):NvInferRuntime.h接口头文件分析


上一篇 TensorRT C++调用

Comments

Content