TensorRT IPluginCreator 实现
- 私有成员变量
- PReLUPluginCreator()
- virtual const char* getPluginName() const override
- virtual const char* getPluginVersion() const override
- virtual const nvinfer1::PluginFieldCollection* getFieldNames() override
- virtual nvinfer1::IPluginV2* createPlugin(const char* name, const nvinfer1::PluginFieldCollection *fc) override
- virtual nvinfer1::IPluginV2* deserializePlugin(const char* name, const void* serialData, size_t serialLength) override;
- virtual void setPluginNamespace(const char* pluginNamespace) override {}
- virtual const char* getPluginNamespace() const override;
- 注册
私有成员变量
PluginField 的成员变量有 name、data、type、length（其中 type 的类型是枚举 PluginFieldType）。
在 TensorRT 的 Python API 中，PluginFieldCollection 类似于列表，提供 append()、extend()、insert()、pop() 等函数，操作对象都是 PluginField；在 C++ API 中，它只是一个记录字段个数和字段数组指针的结构体：
struct PluginFieldCollection
{
int nbFields; //!< Number of PluginField entriesconst PluginField* fields; //!< Pointer to PluginField entries
};
// Collection returned by getFieldNames(); the constructor points its
// `fields` at the storage of mPluginAttributes below.
nvinfer1::PluginFieldCollection mFC;
// Backing storage for the field descriptors referenced by mFC.
std::vector<nvinfer1::PluginField> mPluginAttributes;
PReLUPluginCreator()
为每个参数构造一个 PluginField（传入 name、data、type、length），存入 mPluginAttributes，再让 mFC 指向这个数组。字段名需要与 createPlugin 中比较的名字保持一致。
// Builds the list of attribute descriptors that getFieldNames() advertises:
// "weights" (kFLOAT32) and "nbWeights" (kINT32). Data pointers are left null;
// these entries only describe the fields, and the actual values are supplied
// through the collection passed to createPlugin().
PReLUPluginCreator::PReLUPluginCreator() {
    mPluginAttributes.emplace_back("weights", nullptr, nvinfer1::PluginFieldType::kFLOAT32, 1);
    // Must match the attribute name createPlugin() looks for ("nbWeights").
    mPluginAttributes.emplace_back("nbWeights", nullptr, nvinfer1::PluginFieldType::kINT32, 1);
    // mFC borrows the vector's storage: the vector must not be modified after
    // this point, or mFC.fields would dangle.
    mFC.nbFields = static_cast<int>(mPluginAttributes.size());
    mFC.fields = mPluginAttributes.data();
}
virtual const char* getPluginName() const override
// Reports this creator's plugin name (the G_PRELU_NAME constant).
const char* PReLUPluginCreator::getPluginName() const
{
    const char* const pluginName = G_PRELU_NAME;
    return pluginName;
}
virtual const char* getPluginVersion() const override
// Reports this creator's plugin version (the G_PLUGIN_VERSION constant).
const char* PReLUPluginCreator::getPluginVersion() const
{
    const char* const pluginVersion = G_PLUGIN_VERSION;
    return pluginVersion;
}
virtual const nvinfer1::PluginFieldCollection* getFieldNames() override
返回 createPlugin 所需字段的描述（即 mFC）
// Exposes the field descriptors built in the constructor. The returned pointer
// refers to a member of this creator, so it is only valid while the creator lives.
const nvinfer1::PluginFieldCollection* PReLUPluginCreator::getFieldNames()
{
    const nvinfer1::PluginFieldCollection* collection = &mFC;
    return collection;
}
virtual nvinfer1::IPluginV2* createPlugin(const char* name, const nvinfer1::PluginFieldCollection *fc) override
创建一个新的插件实例：遍历 fc 中的字段，按字段名读取对应的值，然后调用层的普通（非反序列化）构造函数。参数 name 是插件实例名，不是字段名。
// Creates a PReLUPlugin from the attributes in `fc`.
//   "weights"   (kFLOAT32, `length` values) -> copied into a local buffer
//   "nbWeights" (kINT32, one value)         -> passed as the weight count
// `name` (the plugin instance name) is not used.
// If "nbWeights" is absent, 0 is passed; previously the value was uninitialized.
// Returns a heap-allocated plugin; ownership follows TensorRT's createPlugin contract.
nvinfer1::IPluginV2* PReLUPluginCreator::createPlugin(const char* name, const nvinfer1::PluginFieldCollection* fc) {
    ASSERT(fc != nullptr);
    int nbWeights = 0;
    std::vector<float> weightValues;
    const nvinfer1::PluginField* fields = fc->fields;
    for (int i = 0; i < fc->nbFields; ++i) {
        const char* attrName = fields[i].name;
        // strcmp returns 0 on equality, so compare explicitly against 0.
        // "nbWeight" is also accepted because that is the spelling the
        // constructor originally advertised via getFieldNames().
        if (strcmp(attrName, "nbWeights") == 0 || strcmp(attrName, "nbWeight") == 0) {
            ASSERT(fields[i].type == nvinfer1::PluginFieldType::kINT32);
            nbWeights = *(static_cast<const int*>(fields[i].data));
        } else if (strcmp(attrName, "weights") == 0) {
            ASSERT(fields[i].type == nvinfer1::PluginFieldType::kFLOAT32);
            const auto* w = static_cast<const float*>(fields[i].data);
            // Copy exactly `length` floats; the old loop was bounded by the
            // (empty) vector's size and therefore copied nothing.
            if (w != nullptr && fields[i].length > 0) {
                weightValues.assign(w, w + fields[i].length);
            }
        }
    }
    // NOTE(review): `weights.values` points into the local `weightValues`
    // buffer; this assumes PReLUPlugin copies the data in its constructor.
    nvinfer1::Weights weights{
        nvinfer1::DataType::kFLOAT, weightValues.data(), static_cast<int64_t>(weightValues.size())};
    return new PReLUPlugin(&weights, nbWeights);
}
virtual nvinfer1::IPluginV2* deserializePlugin(const char* name, const void* serialData, size_t serialLength) override;
调用层的反序列化构造函数 PReLUPlugin(serialData, serialLength)，并返回新建的插件实例
// Rebuilds a plugin from a serialized buffer by forwarding it to the
// PReLUPlugin(serialData, serialLength) constructor. `layerName` is not used.
nvinfer1::IPluginV2* PReLUPluginCreator::deserializePlugin(const char *layerName, const void *serialData, size_t serialLength)
{
    auto* restored = new PReLUPlugin(serialData, serialLength);
    return restored;
}
virtual void setPluginNamespace(const char* pluginNamespace) override {}
以空函数体重写即可（传入的命名空间会被忽略，getPluginNamespace 始终返回 G_PLUGIN_NAMESPACE）
virtual const char* getPluginNamespace() const override;
// Reports the fixed namespace G_PLUGIN_NAMESPACE. Any value passed to
// setPluginNamespace() is not stored, so it never affects this result.
const char* PReLUPluginCreator::getPluginNamespace() const
{
    const char* const pluginNamespace = G_PLUGIN_NAMESPACE;
    return pluginNamespace;
}
注册
// Registers the creator with TensorRT's plugin registry. The macro expands to a
// static registrar declaration without a trailing semicolon, so one is required here.
REGISTER_TENSORRT_PLUGIN(PReLUPluginCreator);