// vpt_fpn_plugin_factory.h
#ifndef __VPT_FPN_PLUGIN_FACTORY_H__
#define __VPT_FPN_PLUGIN_FACTORY_H__
#include <cassert>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <memory>
#include <utility>
#include <vector>
#include "NvInfer.h"
#include "NvCaffeParser.h"
#include "NvInferPlugin.h"
#include "utools.h"
//#include "caffe2/tensorrt/plugins/caffe_proposal_with_anchor_plugin.h"
//#include "caffe2/tensorrt/plugins/caffe_collect_distribute_fpn_rpn_proposals_plugin.h"
//#include "caffe2/tensorrt/plugins/caffe_roi_align_plugin.h"
//#include "caffe2/tensorrt/plugins/caffe_fpn_concat_plugin.h"
//#include "caffe2/tensorrt/plugins/caffe_batch_reindex_plugin.h"
////#include "caffe2/tensorrt/plugins/caffe_generate_proposals_plugin.h"
//#include "caffe2/tensorrt/plugins/caffe_bbox_transform_plugin.h"

using namespace nvinfer1;
using namespace nvcaffeparser1;
using namespace nvinfer1::plugin;
//using namespace caffe2;


// integration for serialization
// Plugin factory for the VPT FPN network, used both when building and when
// deserializing a TensorRT engine:
//  - as nvcaffeparser1::IPluginFactory, the Caffe parser calls isPlugin() and
//    createPlugin(layerName, weights, nbWeights) for plugin layers;
//  - as nvinfer1::IPluginFactory, engine deserialization calls
//    createPlugin(layerName, serialData, serialLength).
// Every plugin created is retained in plugins_, so the raw pointers handed to
// TensorRT stay valid until destroyPlugin() is called or the factory dies.
class VPT_FPNPluginFactory : public nvcaffeparser1::IPluginFactory, public nvinfer1::IPluginFactory
{
public:
	// caffe parser plugin implementation
	// Returns true if `name` is one of the layer names this factory builds.
	bool isPlugin(const char* name) override
	{
		static const char* const kPluginLayerNames[] = {
			"ColDisFpnRpnProposals",
			"proposal_fpn3",
			"proposal_fpn4",
			"proposal_fpn5",
			"proposal_fpn6",
			"roi_align_fpn3",
			"roi_align_fpn4",
			"roi_align_fpn5",
			"concat_roi_feat",
			"batchreindex",
			"bbox_transform",
		};
		if (name == nullptr)
			return false;
		for (const char* candidate : kPluginLayerNames)
		{
			if (std::strcmp(name, candidate) == 0)
				return true;
		}
		return false;
	}

	// Deleter that calls destroy() on an INvPlugin. Not used inside this
	// class; presumably kept for callers that wrap INvPlugin in smart
	// pointers -- TODO confirm it is still needed.
	void(*nvPluginDeleter)(INvPlugin*) {
		[](INvPlugin* ptr)
		{
			printf("%s<%d>:destory nvplugin.\n", __FILE__, __LINE__);
			ptr->destroy();
		}
	};

//#define USE_CaffeProposalWithAnchorPlugin
	// Builds the plugin for `layerName` while parsing the Caffe model.
	// There is no way to pass parameters through from the model definition,
	// so each layer's parameters are hard-coded here; `weights` and
	// `nbWeights` are ignored.
	// Returns a pointer owned by this factory (kept alive in plugins_), or
	// nullptr if the layer is unknown or plugin creation failed (both assert
	// in debug builds; release builds no longer record a null plugin).
	virtual nvinfer1::IPlugin* createPlugin(const char* layerName, const nvinfer1::Weights* weights, int nbWeights) override
	{
		(void)weights;
		(void)nbWeights;
		assert(isPlugin(layerName));
		std::shared_ptr<IPlugin> cur_plugin{ nullptr };

		if (!std::strcmp(layerName, "ColDisFpnRpnProposals"))
		{
			float roi_canonical_scale = 224;
			int roi_canonical_level = 4;
			int roi_max_level = 5;
			int roi_min_level = 3;
			int rpn_max_level = 6;
			int rpn_min_level = 3;
			int rpn_post_nms_topN = 1000;
			cur_plugin = ctools_create_caffe_collect_distribute_fpn_rpn_proposals_plugin(
				roi_canonical_scale, roi_canonical_level,
				roi_max_level, roi_min_level, rpn_max_level,
				rpn_min_level, rpn_post_nms_topN
			);
		}
		else if (!std::strcmp(layerName, "proposal_fpn3"))
		{
			cur_plugin = createProposalPlugin(8);
		}
		else if (!std::strcmp(layerName, "proposal_fpn4"))
		{
			cur_plugin = createProposalPlugin(16);
		}
		else if (!std::strcmp(layerName, "proposal_fpn5"))
		{
			cur_plugin = createProposalPlugin(32);
		}
		else if (!std::strcmp(layerName, "proposal_fpn6"))
		{
			cur_plugin = createProposalPlugin(64);
		}
		else if (!std::strcmp(layerName, "roi_align_fpn3"))
		{
			cur_plugin = createRoiAlignPlugin(0.125f);    // 1/8
		}
		else if (!std::strcmp(layerName, "roi_align_fpn4"))
		{
			cur_plugin = createRoiAlignPlugin(0.0625f);   // 1/16
		}
		else if (!std::strcmp(layerName, "roi_align_fpn5"))
		{
			cur_plugin = createRoiAlignPlugin(0.03125f);  // 1/32
		}
		else if (!std::strcmp(layerName, "concat_roi_feat"))
		{
			cur_plugin = ctools_create_caffe_fpn_concat_plugin();
		}
		else if (!std::strcmp(layerName, "batchreindex"))
		{
			cur_plugin = ctools_create_caffe_batch_reindex_plugin();
		}
		else if (!std::strcmp(layerName, "bbox_transform"))
		{
			cur_plugin = ctools_create_caffe_bbox_transform_plugin();
		}
		else
		{
			// Unknown layer: report it and do not record a null plugin.
			std::cout << layerName << std::endl;
			assert(0);
			return nullptr;
		}
		return retainPlugin(std::move(cur_plugin));
	}

	// Recreates the plugin for `layerName` from its serialized state
	// (`serialData`, `serialLength` bytes) while deserializing an engine.
	// Returns a pointer owned by this factory (kept alive in plugins_), or
	// nullptr if the layer is unknown or plugin creation failed (both assert
	// in debug builds; release builds no longer record a null plugin).
	virtual nvinfer1::IPlugin* createPlugin(const char* layerName, const void* serialData, size_t serialLength) override
	{
		assert(isPlugin(layerName));
		std::shared_ptr<IPlugin> cur_plugin{ nullptr };

		if (!std::strcmp(layerName, "ColDisFpnRpnProposals"))
		{
			cur_plugin = ctools_create_caffe_collect_distribute_fpn_rpn_proposals_plugin(serialData, serialLength);
		}
		else if (!std::strcmp(layerName, "proposal_fpn3")
			|| !std::strcmp(layerName, "proposal_fpn4")
			|| !std::strcmp(layerName, "proposal_fpn5")
			|| !std::strcmp(layerName, "proposal_fpn6"))
		{
			cur_plugin = deserializeProposalPlugin(serialData, serialLength);
		}
		else if (!std::strcmp(layerName, "roi_align_fpn3")
			|| !std::strcmp(layerName, "roi_align_fpn4")
			|| !std::strcmp(layerName, "roi_align_fpn5"))
		{
			cur_plugin = ctools_create_caffe_roi_align_plugin(serialData, serialLength);
		}
		else if (!std::strcmp(layerName, "concat_roi_feat"))
		{
			cur_plugin = ctools_create_caffe_fpn_concat_plugin(serialData, serialLength);
		}
		else if (!std::strcmp(layerName, "batchreindex"))
		{
			cur_plugin = ctools_create_caffe_batch_reindex_plugin(serialData, serialLength);
		}
		else if (!std::strcmp(layerName, "bbox_transform"))
		{
			cur_plugin = ctools_create_caffe_bbox_transform_plugin(serialData, serialLength);
		}
		else
		{
			// Unknown layer: report it and do not record a null plugin.
			std::cout << layerName << std::endl;
			assert(0);
			return nullptr;
		}
		return retainPlugin(std::move(cur_plugin));
	}

	// User application destroys plugins when it is safe to do so.
	// Should be done after consumers of the plugins (like ICudaEngine) are destroyed.
	// Drops this factory's references; each plugin is freed once no other
	// shared_ptr still owns it.
	void destroyPlugin()
	{
		plugins_.clear();
	}

	// Every plugin created by this factory; holding them here keeps the raw
	// pointers returned by createPlugin() valid.
	std::vector< std::shared_ptr<IPlugin> > plugins_;

private:
	// Stores `plugin` in plugins_ and returns its raw pointer for TensorRT.
	// A null plugin asserts in debug builds and is never stored; nullptr is
	// returned instead.
	nvinfer1::IPlugin* retainPlugin(std::shared_ptr<IPlugin> plugin)
	{
		assert(plugin.get());
		if (!plugin)
			return nullptr;
		IPlugin* raw = plugin.get();
		plugins_.push_back(std::move(plugin));
		return raw;
	}

	// Builds the proposal plugin for one FPN level from hard-coded parameters.
	// All levels share these values; only `feat_stride` differs
	// (8/16/32/64 for proposal_fpn3..proposal_fpn6).
	static std::shared_ptr<IPlugin> createProposalPlugin(int feat_stride)
	{
		int pre_nms_topN = 1000;
		int post_nms_topN = 1000;
		float nms_thresh = 0.7f;
		int min_size = 0;
		std::vector<float> anchor_scale({ 4.0f });
		std::vector<float> anchor_ratio({ 0.5f, 1.0f, 2.0f });
		std::shared_ptr<IPlugin> plugin{ nullptr };
#ifdef USE_CaffeProposalWithAnchorPlugin
		plugin = ctools_create_caffe_proposal_with_anchor_plugin(feat_stride, pre_nms_topN, post_nms_topN, nms_thresh, min_size, anchor_scale, anchor_ratio);
#else
		bool correct_transform_coords = true;
		bool angle_bound_on = true;
		int angle_bound_lo = -90;
		int angle_bound_hi = 90;
		float clip_angle_thresh = 1.0f;
		plugin = ctools_create_caffe_generate_proposals_plugin(1.0f / feat_stride, feat_stride, pre_nms_topN,
			post_nms_topN, nms_thresh, min_size,
			anchor_scale, anchor_ratio,
			correct_transform_coords, angle_bound_on, angle_bound_lo,
			angle_bound_hi, clip_angle_thresh);
#endif
		return plugin;
	}

	// Recreates a proposal plugin (any FPN level) from serialized state, using
	// the same implementation selected by USE_CaffeProposalWithAnchorPlugin
	// as createProposalPlugin().
	static std::shared_ptr<IPlugin> deserializeProposalPlugin(const void* serialData, size_t serialLength)
	{
		std::shared_ptr<IPlugin> plugin{ nullptr };
#ifdef USE_CaffeProposalWithAnchorPlugin
		plugin = ctools_create_caffe_proposal_with_anchor_plugin(serialData, serialLength);
#else
		plugin = ctools_create_caffe_generate_proposals_plugin(serialData, serialLength);
#endif
		return plugin;
	}

	// Builds the RoIAlign plugin for one FPN level (pooled_height =
	// pooled_width = 7, sample_num = 2); levels differ only in
	// `spatial_scale` (1/8, 1/16, 1/32 for roi_align_fpn3..roi_align_fpn5).
	static std::shared_ptr<IPlugin> createRoiAlignPlugin(float spatial_scale)
	{
		int pooled_height = 7;
		int pooled_width = 7;
		int sample_num = 2;
		std::shared_ptr<IPlugin> plugin{ nullptr };
		plugin = ctools_create_caffe_roi_align_plugin(pooled_height, pooled_width, spatial_scale, sample_num);
		return plugin;
	}
};


#endif // __VPT_FPN_PLUGIN_FACTORY_H__