// VPTProcess.cpp (7.49 KB)
#include "VPTProcess.h"
#include "../common/logger.hpp"
#include "../ai_platform/task_param_manager.h"

#include <stdlib.h>
#include <time.h>
#include <fstream>

#include "vpt.h"
#include "../ai_platform/macro_definition.h"
#include "../ai_platform/det_obj_header.h"

#include "opencv2/opencv.hpp"

#include "../util/vpc_util.h"

/*
 * Sets the default inference batch size (also used as the chunk size in
 * process_gpu) and puts the ACL/detector handles into a known "not created"
 * state. The destructor calls release(), which reads m_det_handle,
 * m_algorthim_ctx and m_devId, so they must not be left indeterminate when
 * init() is never called or fails early.
 * NOTE(review): if the header already gives these members in-class
 * initializers, the assignments below are redundant but harmless.
 */
VPTProcess::VPTProcess(){
    m_max_batchsize = 16;
    m_det_handle = nullptr;
    m_algorthim_ctx = nullptr;
    m_devId = 0;
}

/* Destructor: frees the detector handle and ACL context created by init()
   by delegating to release(). */
VPTProcess::~VPTProcess()
{
    this->release();
}

/*
 * Algorithm initialization.
 *
 * Binds device vparam.gpuid, creates a dedicated ACL context on it, and
 * creates the detector handle from the batch-16 model under vparam.model_dir.
 *
 * Returns 0 on success, -1 if any ACL call or vpt_init fails. Resources
 * created before a failure (device binding, context) are left for release(),
 * which the destructor calls.
 */
int VPTProcess::init(VPTProcess_PARAM vparam){

    // string model_path = vparam.model_dir + "/models/vpt230323_310p.om" ;
    string model_path = vparam.model_dir + "/models/vpt230323_b16_310p.om" ;

    LOG_INFO("vpt 版本:{}  模型路径:{}", vpt_get_version(), model_path);

    // Writable, NUL-terminated copy of the path, sized to fit. The previous
    // fixed char[100] + strcpy overflowed the stack for paths of 100+ chars.
    // It lives until init() returns, just like the old stack buffer did.
    std::vector<char> model_name(model_path.c_str(), model_path.c_str() + model_path.size() + 1);

    vpt_param param{};   // value-initialize so fields not set below are not garbage
    param.modelNames_b = model_name.data();
    param.threshold = vparam.threshold;
    param.devId = vparam.gpuid;
    param.isTrk = false;
    // Same value as the chunk size process_gpu uses when splitting large batches.
    param.max_batch = m_max_batchsize;

    m_devId = param.devId;
    ACL_CALL(aclrtSetDevice(m_devId), ACL_SUCCESS, -1);
    ACL_CALL(aclrtCreateContext(&m_algorthim_ctx, m_devId), ACL_SUCCESS, -1);

    int ret = vpt_init(&m_det_handle, param);
    if(ret != 0){
        LOG_ERROR("vpt init error. ret:{}", ret);
        return -1;
    }

    return 0;
}

/*
 * Algorithm computation: run VPT detection on a batch of images, then feed the
 * detections through each task's SORT tracker.
 *
 * batch_img       images; batch_img[i] belongs to tasklist[i].
 * tasklist        task IDs; its size is the batch size.
 * result          per-image tracked output (obj, obj_count, task_id); grown to
 *                 tasklist.size() if smaller.
 * deleteObjectID  per-image output list passed to Sort::update_v3; grown to
 *                 tasklist.size() if smaller.
 * unUsedResult    not used by this implementation.
 *
 * Returns 0 on success. Returns -1 if setting the ACL context or running the
 * detector fails; in that case no tracker is updated, so trackers are never fed
 * empty or stale detections.
 */
int VPTProcess::process_gpu(sy_img * batch_img, vector<string>& tasklist,
					vector<onelevel_det_result>& result, vector<vector<int>>& deleteObjectID, vector<vector<onelevel_det_result>>& unUsedResult)
{
    (void)unUsedResult;

    const int batchsize = static_cast<int>(tasklist.size());

    // Outputs are indexed per image below; make sure both are large enough.
    if (result.size() < tasklist.size())
        result.resize(tasklist.size());
    if (deleteObjectID.size() < tasklist.size())
        deleteObjectID.resize(tasklist.size());

    /* Result buffers: owned by vectors so every exit path frees them. */
    std::vector<std::vector<vpt_obj_result>> obj_storage(batchsize);
    std::vector<vpt_result> vpt_det_result(batchsize);
    for (int b = 0; b < batchsize; b++) {
        obj_storage[b].resize(MAX_DET_COUNT);
        vpt_det_result[b].obj_count_ = 0;
        vpt_det_result[b].obj_results_ = obj_storage[b].data();
    }

    /* Too many streams for one call: run in chunks of at most m_max_batchsize. */
    const int chunk = m_max_batchsize > 0 ? m_max_batchsize : 1;
    for (int start = 0; start < batchsize; start += chunk) {
        const int count = (batchsize - start < chunk) ? (batchsize - start) : chunk;

        int ret = aclrtSetCurrentContext(m_algorthim_ctx);
        if (ACL_SUCCESS != ret) {
            LOG_ERROR("aclrtSetCurrentContext failed, ret:{}", ret);
            return -1;
        }
        ret = vpt_batchV2(m_det_handle, batch_img + start, count, vpt_det_result.data() + start);
        if (ret != 0) {
            LOG_ERROR("vpt_batchV2 failed, ret:{}", ret);
            return -1;
        }
    }

    /* Convert one image's detections into SORT input rows
       [x1, y1, x2, y2, score, class_id], keeping scores >= THRESHOLD.
       Reads at most MAX_OBJ_COUNT entries and never more than the
       MAX_DET_COUNT buffer allocated above. */
    auto to_sort_input = [&](const vpt_result &res) {
        std::vector<std::vector<float>> dets;
        for (int c = 0; c < res.obj_count_ && c < MAX_OBJ_COUNT && c < MAX_DET_COUNT; c++) {
            const auto &obj = res.obj_results_[c];
            const float score = obj.obj_score;
            if (score >= THRESHOLD) {
                const float x1 = obj.obj_rect.left_;
                const float y1 = obj.obj_rect.top_;
                const float x2 = obj.obj_rect.left_ + obj.obj_rect.width_;
                const float y2 = obj.obj_rect.top_ + obj.obj_rect.height_;
                const float class_id = obj.obj_index;
                dets.push_back({x1, y1, x2, y2, score, class_id});
            }
        }
        return dets;
    };

    for (int i = 0; i < batchsize; i++) {
        const string &task_id = tasklist[i];

        // find() instead of operator[]: an unregistered task must not get a
        // tracker created implicitly (and kept in the map forever).
        auto it = taskTrackers.find(task_id);
        if (it == taskTrackers.end())
            continue;
        Sort &cur_sort = it->second.tracker;
        if (!cur_sort.GetState())
            continue;

        std::vector<std::vector<float>> dets = to_sort_input(vpt_det_result[i]);

        // Image diagonal, computed in double to avoid int overflow. Passed to update_v3 as maxLen.
        const double w = batch_img[i].w_;
        const double h = batch_img[i].h_;
        const float maxLen = static_cast<float>(std::sqrt(w * w + h * h));

        /* FusionInterval is the frame-skip parameter: the first update uses the
           detections; the remaining FusionInterval-1 updates are pure tracking. */
        bool isUseDet = true;
        for (int j = 0; j < cur_sort.FusionInterval; j++) {
            if (j == 0) {
                /* Detection-assisted update: this result is returned. */
                const int objCount = cur_sort.update_v3(isUseDet, /*save lk = */false, /*center_dist = */false, maxLen, dets, result[i].obj, deleteObjectID[i]);
                result[i].obj_count = objCount;
                result[i].task_id = task_id;

                dets.clear();
                isUseDet = false;
            } else {
                /* Pure tracking update: advances the tracker, result discarded. */
                onelevel_det_result un_result;
                un_result.obj_count = cur_sort.update_v3(isUseDet, false, false, maxLen, dets, un_result.obj, deleteObjectID[i]);
            }
        }
    }

    return 0;
}


/*
 * Releases the algorithm handle and the ACL context created by init(), then
 * resets the device.
 *
 * Idempotent: both handles are nulled after release, and a call with nothing
 * held returns immediately. This prevents a double aclrtDestroyContext and a
 * redundant device set/reset when release() is called explicitly and again
 * from the destructor.
 */
void VPTProcess::release(){
    if (!m_det_handle && !m_algorthim_ctx) {
        return;   // nothing created (init not run / failed early) or already released
    }

    aclError ret;
    ret = aclrtSetDevice(m_devId);
    if (m_algorthim_ctx) {
        // Make our context current so vpt_release runs in it.
        aclrtSetCurrentContext(m_algorthim_ctx);
    }

    if (m_det_handle){
        vpt_release(&m_det_handle);
        m_det_handle = NULL;
    }
    if(m_algorthim_ctx){
        aclrtDestroyContext(m_algorthim_ctx);
        m_algorthim_ctx = nullptr;   // guard against destroying it twice
    }

    ret = aclrtResetDevice(m_devId);
    if (ret != ACL_SUCCESS) {
        LOG_ERROR("reset device failed");
    }
}

// 221117byzsh
void VPTProcess::addTaskTracker(const string taskID, double rWidth, double rHeight, int skip_frame)
{
	TaskTracker t;
	t.TaskID = taskID;
	t.ratioWidth = rWidth;
	t.ratioHeight = rHeight;
	t.tracker.FusionInterval = skip_frame;

	taskTrackers[taskID] = t;
}

/*
 * Task finished: drop the task's tracker entry. Always reports success.
 * NOTE(review): assuming taskTrackers is a std::map-like container, erasing an
 * ID that was never registered is a no-op.
 */
bool VPTProcess::finishTaskTracker(const string taskID)
{
	const bool finished = true;
	taskTrackers.erase(taskID);
	return finished;
}