当前位置: 代码迷 >> 综合 >> tensorrt IPluginCreator实现
  详细解决方案

tensorrt IPluginCreator实现

热度:26   发布时间:2023-10-21 21:54:24.0

tensorrt IPluginCreator实现

  • 私有成员函数
  • PReLUPluginCreator()
  • virtual const char* getPluginName() const override
  • virtual const char* getPluginVersion() const override
  • virtual const nvinfer1::PluginFieldCollection* getFieldNames() override
  • virtual nvinfer1::IPluginV2* createPlugin(const char* name, const nvinfer1::PluginFieldCollection *fc) override
  • virtual nvinfer1::IPluginV2* deserializePlugin(const char* name, const void* serialData, size_t serialLength) override;
  • virtual void setPluginNamespace(const char* pluginNamespace) override {}
  • virtual const char* getPluginNamespace() const override;
  • 注册

私有成员函数

PluginField的成员变量有name、data、type、length(PluginFieldType只是描述data中元素类型的枚举)。
PluginFieldCollection在Python API中包括append()、extend()、insert()、pop()函数,其中的操作对象都是PluginField类型;在C++中它只是一个包含nbFields和fields两个成员的结构体:

/// Describes a list of plugin fields as a count plus a pointer to the first entry.
struct PluginFieldCollection
{
    int nbFields;                     //!< Number of PluginField entries
    // The elaborated type specifier lets this snippet compile even when
    // PluginField has not been declared before this point.
    const struct PluginField* fields; //!< Pointer to PluginField entries
};
 /// Field collection whose address is returned by getFieldNames().
 nvinfer1::PluginFieldCollection mFC;

 /// Backing storage for the field descriptors that mFC.fields points at.
 std::vector<nvinfer1::PluginField> mPluginAttributes;

PReLUPluginCreator()

将每个参数的name、data、type、length依次存入mPluginAttributes,再让mFC指向这些字段

/// Registers the descriptors of the fields this creator understands
/// ("weights" and "nbWeights") and points mFC at them.
PReLUPluginCreator::PReLUPluginCreator() {
    // Only name, type and length are described here; the data pointer is left null.
    mPluginAttributes.emplace_back("weights", nullptr, nvinfer1::PluginFieldType::kFLOAT32, 1);
    // Spelled "nbWeights" (not "nbWeight") so it matches the name createPlugin() looks up.
    mPluginAttributes.emplace_back("nbWeights", nullptr, nvinfer1::PluginFieldType::kINT32, 1);

    // Publish the vector only after every insertion: growing it would invalidate data().
    mFC.nbFields = static_cast<int>(mPluginAttributes.size());
    mFC.fields = mPluginAttributes.data();
}

virtual const char* getPluginName() const override

/// Returns the plugin name, taken from the G_PRELU_NAME constant.
const char* PReLUPluginCreator::getPluginName() const {
    const char* const pluginName = G_PRELU_NAME;
    return pluginName;
}

virtual const char* getPluginVersion() const override

/// Returns the plugin version, taken from the G_PLUGIN_VERSION constant.
const char* PReLUPluginCreator::getPluginVersion() const {
    const char* const pluginVersion = G_PLUGIN_VERSION;
    return pluginVersion;
}

virtual const nvinfer1::PluginFieldCollection* getFieldNames() override

返回需要被传入createPlugin的fields

/// Returns a pointer to the creator's own field collection (mFC).
/// The pointer refers to a member, so it is only valid while this creator exists.
const nvinfer1::PluginFieldCollection* PReLUPluginCreator::getFieldNames() {
    const nvinfer1::PluginFieldCollection* const fieldCollection = &mFC;
    return fieldCollection;
}

virtual nvinfer1::IPluginV2* createPlugin(const char* name, const nvinfer1::PluginFieldCollection *fc) override

创建一个新的插件实例:根据字段名(name)从fc中读取存储的值,并调用层的构造函数(未序列化的构造函数,即不从序列化数据构造的那个)

/// Builds a PReLUPlugin from a caller-supplied field collection.
///
/// Recognised fields:
///   - "nbWeights" (kINT32):   read as a single int.
///   - "weights"   (kFLOAT32): `length` floats, copied into a local buffer.
/// Fields with any other name are ignored.
///
/// @param name  Not used by this implementation.
/// @param fc    Field collection to read from; must not be null.
/// @return A PReLUPlugin allocated with `new`.
nvinfer1::IPluginV2* PReLUPluginCreator::createPlugin(const char* name, const nvinfer1::PluginFieldCollection* fc) {
    ASSERT(fc != nullptr);

    int nbWeights = 0;  // stays 0 when the caller supplies no "nbWeights" field
    std::vector<float> weightValues;

    for (int i = 0; i < fc->nbFields; ++i) {
        const nvinfer1::PluginField& field = fc->fields[i];
        if (field.name == nullptr) {
            continue;  // nothing to match against
        }

        // strcmp() returns 0 on a match, so the result must be compared explicitly.
        if (strcmp(field.name, "nbWeights") == 0) {
            ASSERT(field.type == nvinfer1::PluginFieldType::kINT32);
            ASSERT(field.data != nullptr);
            nbWeights = *static_cast<const int*>(field.data);
        } else if (strcmp(field.name, "weights") == 0) {
            ASSERT(field.type == nvinfer1::PluginFieldType::kFLOAT32);
            ASSERT(field.length >= 0);
            ASSERT(field.length == 0 || field.data != nullptr);
            // Copy exactly `length` floats; the bound is the field's length,
            // not the (initially empty) destination vector's size.
            const auto* first = static_cast<const float*>(field.data);
            weightValues.assign(first, first + field.length);
        }
    }

    // NOTE(review): `weights` refers to the local buffer above, which dies when
    // this function returns — assumes PReLUPlugin copies the values; confirm.
    nvinfer1::Weights weights{nvinfer1::DataType::kFLOAT, weightValues.data(),
                              static_cast<int64_t>(weightValues.size())};
    return new PReLUPlugin(&weights, nbWeights);
}

virtual nvinfer1::IPluginV2* deserializePlugin(const char* name, const void* serialData, size_t serialLength) override;

调用层的反序列化构造函数,用序列化数据重建插件并返回

/// Recreates a plugin from serialized data.
///
/// @param layerName     Not used by this implementation.
/// @param serialData    Serialized buffer, forwarded to the PReLUPlugin constructor.
/// @param serialLength  Size of the buffer, forwarded to the PReLUPlugin constructor.
/// @return A PReLUPlugin allocated with `new`.
nvinfer1::IPluginV2* PReLUPluginCreator::deserializePlugin(const char *layerName, const void *serialData, size_t serialLength) {
    nvinfer1::IPluginV2* const restoredPlugin = new PReLUPlugin(serialData, serialLength);
    return restoredPlugin;
}

virtual void setPluginNamespace(const char* pluginNamespace) override {}

重写为空实现即可

virtual const char* getPluginNamespace() const override;

/// Returns the plugin namespace, taken from the G_PLUGIN_NAMESPACE constant.
const char* PReLUPluginCreator::getPluginNamespace() const {
    const char* const pluginNamespace = G_PLUGIN_NAMESPACE;
    return pluginNamespace;
}

注册

// Registers PReLUPluginCreator through TensorRT's registration macro.
// The macro expands to a declaration with no trailing semicolon, so the
// statement has to be terminated here.
REGISTER_TENSORRT_PLUGIN(PReLUPluginCreator);