tricycle_manned_process.cpp 12.8 KB
#include <algorithm>
#include <cmath>
#include <cstring>
#include <vector>

#include "./tricycle_manned_process.h"
#include "../decoder/interface/DeviceMemory.hpp"
#include "../common/logger.hpp"
#include "../ai_platform/mvpt_process_assist.h"


namespace ai_engine_module
{
    namespace tricycle_manned_process
    {
        algorithm_type_t TricycleMannedProcess::algor_type_ = algorithm_type_t::TRICYCLE_MANNED;

        /**
         * Constructs an uninitialised processor; init() must succeed before update_mstreams()
         * does any work.
         *
         * The SDK handle and ACL context are explicitly nulled here because the destructor
         * tests them: without this, destroying an object whose init() was never called (or
         * failed before creating them) would read indeterminate pointers unless the header
         * happens to provide default member initializers.
         */
        TricycleMannedProcess::TricycleMannedProcess()
            : task_param_manager_(nullptr)
        {
            tools_ = nullptr;
            m_algorthim_ctx = nullptr;
            init_ = false;
        }

        /**
         * Releases the hs_tri SDK handle and then the ACL context created by init().
         */
        TricycleMannedProcess::~TricycleMannedProcess()
        {
            // init() creates m_algorthim_ctx (which becomes the current context) before calling
            // hs_tri_init, so the SDK's device resources belong to that context. Make it current
            // again before releasing them; the destructor may run on a different thread.
            if (m_algorthim_ctx) {
                aclrtSetCurrentContext(m_algorthim_ctx);
            }
            if (tools_) {
                hs_tri_release(&tools_);
                tools_ = nullptr;
            }
            if (m_algorthim_ctx) {
                aclrtDestroyContext(m_algorthim_ctx);
                m_algorthim_ctx = nullptr;
            }
        }

        bool TricycleMannedProcess::init(int gpu_id, string models_dir)
        {
            init_ = false;

            // string model_path = models_dir + "/models/hs/hs_tricycle_310p.om" ;
            string model_path = models_dir + "/models/hs/hs_tricycle_b8_310p.om" ;
            LOG_INFO("hs_tri 版本:{}  模型路径:{}", hs_tri_getversion(), model_path);

            hs_tri_param param;
            char modelNames[100];
            strcpy(modelNames, model_path.c_str());
            // param.modelNames = modelNames;
            param.modelNames_b = modelNames;
            param.thresld = 0.3;
            param.devId = gpu_id;
            param.max_batch = 8;
          
            m_devId = param.devId;
            ACL_CALL(aclrtSetDevice(m_devId), ACL_SUCCESS, -1);
            ACL_CALL(aclrtCreateContext(&m_algorthim_ctx, m_devId), ACL_SUCCESS, -1);
            
            int status;
            if (!(init_ = (0 == (status = hs_tri_init(&tools_, param)))))
                LOG_ERROR("Init TricycleMannedProcessSdk failed error code is {}", status);
            else
                if (!task_param_manager_)
                    task_param_manager_ = task_param_manager::getInstance();
            return init_;
        }


        /**
         * Returns whether init() has succeeded, logging an error when it has not.
         */
        bool TricycleMannedProcess::check_initied()
        {
            // The logger uses {fmt}-style placeholders (as every other log call in this file
            // does); the former "%s:%d" was printed literally and the arguments were dropped.
            if (!init_)
                LOG_ERROR("[{}:{}] call init function please.", __FILE__, __LINE__);
            return init_;
        }


        /**
         * Discards every cached entry that belongs to task_id: the pending alarm results
         * (freeing any image descriptors they still hold) and the per-object m/n counters.
         */
        void TricycleMannedProcess::force_release_result(const task_id_t& task_id) {
            auto res_it = id_to_result_.begin();
            while (res_it != id_to_result_.end()) {
                if (!(res_it->first.task_id == task_id)) {
                    ++res_it;
                    continue;
                }

                auto& cached = res_it->second;
                if (cached.origin_img_desc != nullptr)
                    VPCUtil::vpc_pic_desc_release(cached.origin_img_desc);
                if (cached.roi_img_desc != nullptr)
                    VPCUtil::vpc_pic_desc_release(cached.roi_img_desc);

                res_it = id_to_result_.erase(res_it);
            }

            auto mn_it = id_to_mn_.begin();
            while (mn_it != id_to_mn_.end()) {
                if (mn_it->first.task_id == task_id)
                    mn_it = id_to_mn_.erase(mn_it);
                else
                    ++mn_it;
            }
        }

        std::shared_ptr<results_data_t> TricycleMannedProcess::get_result_by_objectid(const id_t& id, bool do_erase)
        {
            auto it = id_to_result_.find(id);
            if (it == id_to_result_.end())
                return std::shared_ptr<results_data_t>(nullptr);
            std::shared_ptr<results_data_t> res = std::make_shared<results_data_t>(it->second);
            if (do_erase) {
                id_to_result_.erase(id);
                if (id_to_mn_.count(id)) id_to_mn_.erase(id);
            }
            return res;
        }
        
        bool TricycleMannedProcess::update_mstreams(const std::vector<task_id_t>& taskIds, vector<DeviceMemory*> vec_det_input_images, const std::vector<onelevel_det_result> &det_results) 
        {
            if (!check_initied())
                return false;

            if (det_results.empty())
            {
                LOG_DEBUG("detection result is empty.");
                return false;
            }

            int n_images = det_results.size();  // or n_stream

            unsigned flattened_idx = 0;
            std::map<int, int> flattened_idx_to_batch_idx;

            /* 1. Crop & keep some interest class. */
            auto taskId_iter = taskIds.begin();
            std::vector<sy_img> flattened_imgs(0);
            std::vector<vpc_img_info> flattened_vpc_imgs(0);
            std::vector<input_data_wrap_t> flattened_interest_data(0);  //
            VPCUtil* pVpcUtil = VPCUtil::getInstance();
            for (int n = 0; n < n_images; ++n)
            {
                int n_interest_obj = 0;
                auto& src_img = vec_det_input_images[n];
                int src_img_w = src_img->getWidth();
                int src_img_h = src_img->getHeight();

                auto& boxes_of_one_image = det_results[n].obj;
                for (int i = 0; i < det_results[n].obj_count; ++i)
                {
                    auto& box = boxes_of_one_image[i];
                    if (static_cast<det_class_label_t>(box.index) == det_class_label_t::TRICYCLE)
                    {
                        auto& taskId = *taskId_iter;
                        auto algor_param_wrap = task_param_manager_->get_task_other_param(taskId, this->algor_type_);
                        if (!algor_param_wrap)
                        {
                            LOG_ERROR("{} is nullptr when get algor param from task_param", taskId.c_str());
                            continue;
                        }
                        auto algor_param = ((algor_param_type)algor_param_wrap->algor_param);

                        input_data_wrap_t data;
                        int top = std::max(int(box.top - (IMAGE_CROP_EXPAND_RATIO * box.top)), 0);
                        int left = std::max(int(box.left - (IMAGE_CROP_EXPAND_RATIO * box.left)), 0);
                        int right = std::min(int(box.right + (IMAGE_CROP_EXPAND_RATIO * box.right)), src_img_w);
                        int bottom = std::min(int(box.bottom + (IMAGE_CROP_EXPAND_RATIO * box.bottom)), src_img_h);

                        int width = right - left;
                        int height = bottom - top;

                        if ((width < algor_param->obj_min_width || height < algor_param->obj_min_height || box.confidence < algor_param->obj_confidence_threshold) ||
                            !snapshot_legal_inarea(algor_param_wrap->basic_param->algor_valid_rect, left, top, right, bottom))
                            continue;

                        data.box.top = top;
                        data.box.left = left;
                        data.box.right = right;
                        data.box.bottom = bottom;
                        data.box.score = box.confidence;
                        data.taskId = taskId;
                        data.objId = box.id;
                        data.id = obj_key_t{ box.id, taskId, algorithm_type_t::TRICYCLE_MANNED };

                        // 抠图
                        video_object_info obj;
                        strcpy(obj.task_id, taskId.c_str());
                        obj.object_id = box.id;
                        obj.left = left;    obj.top = top;
                        obj.right = right;  obj.bottom = bottom;

                        vpc_img_info img_info = pVpcUtil->crop(src_img, obj);

                        sy_img img;
                        img.w_ = width;
                        img.h_ = height;
                        img.c_ = src_img->getChannel();
                        
                        if (img_info.pic_desc != nullptr) {
                            void *outputDataDev = acldvppGetPicDescData(img_info.pic_desc);
                            img.data_ = reinterpret_cast<unsigned char*>(outputDataDev);
                        }
                        else {
                            LOG_ERROR("Crop image NPU failed wh is [{}, {}] ltrb is [{} {} {} {}]",
                                src_img_w, src_img_h, data.box.left, data.box.top, data.box.right, data.box.bottom);
                            continue;
                        }
                      
                        flattened_imgs.emplace_back(std::move(img));
                        flattened_vpc_imgs.emplace_back(std::move(img_info));
                        flattened_interest_data.emplace_back(std::move(data));
                        flattened_idx_to_batch_idx[flattened_idx++] = n;
                    }
                }
                ++taskId_iter;
            }

            int ret = aclrtSetCurrentContext(m_algorthim_ctx);
            if (ACL_SUCCESS != ret) {
                return false;
            }
            /* 2. collection result. */
            int n_input_image = flattened_imgs.size();
            hs_tri_result model_results[n_input_image];
            {
                int steps = (n_input_image + MAX_BATCH - 1) / MAX_BATCH;
                for (int step = 0; step < steps; ++step)
                {
                    int offset = step * MAX_BATCH;
                    int batch_size = (step == steps - 1) ? n_input_image - offset : MAX_BATCH;
                    // hs_tri_process_batch(tools_, flattened_imgs.data() + offset, batch_size, model_results + offset);
                    hs_tri_process_batchV2(tools_, flattened_imgs.data() + offset, batch_size, model_results + offset);
                }
            }

            /* 3. postprocess. */
            {
                for (int n = 0; n < n_input_image; ++n)
                {
                    auto& det_result = flattened_interest_data[n];
                    auto& objId = det_result.objId;
                    if (id_to_result_.find(det_result.id) != id_to_result_.end())
                    {
                        VPCUtil::vpc_img_release(flattened_vpc_imgs[n]); //flattened_imgs[n].data_
                        flattened_imgs[n].data_ = nullptr;
                        continue;
                    }

                    const auto& src_img = vec_det_input_images[flattened_idx_to_batch_idx[n]];

                    auto algor_param_wrap = task_param_manager_->get_task_other_param(det_result.taskId, this->algor_type_);
                    if (!algor_param_wrap)
                    {
                        LOG_ERROR("{} nullptr when get algor param from task_param", det_result.taskId.c_str());
                        VPCUtil::vpc_img_release(flattened_vpc_imgs[n]); //flattened_imgs[n].data_
                        flattened_imgs[n].data_ = nullptr;
                        continue;
                    }

                    auto algor_param = ((algor_param_type)algor_param_wrap->algor_param);

                    int hs_count = model_results[n].objcount;

                    obj_key_t obj_key{ det_result.objId, det_result.taskId, algorithm_type_t::TRICYCLE_MANNED };

                    auto& e = id_to_mn_[obj_key];
                    ++e.m_frame;

                    if (hs_count >= algor_param->hs_count_threshold)
                    {
                        if (++e.n_frame == algor_param->n)
                        {
                            results_data_t result;
                            {
                                result.box = det_result.box;
                                result.taskId = det_result.taskId;
                                result.objId = det_result.objId;
#if 0 /*暂不保存报警时刻的抓拍图,有需要再启用*/
                                // 原图
                                vpc_img_info src_img_info = VPCUtil::vpc_devMem2vpcImg(src_img);
                                result.origin_img_desc = src_img_info.pic_desc;
                                // 抠图
                                result.roi_img_desc = flattened_vpc_imgs[n].pic_desc;
#else
                                VPCUtil::vpc_img_release(flattened_vpc_imgs[n]);
#endif
                            }
                            id_to_result_.emplace(obj_key, std::move(result));
                            goto _continue;
                        }
                    }

                    if (e.m_frame == algor_param->m)
                        e.reset();

                    VPCUtil::vpc_img_release(flattened_vpc_imgs[n]); //flattened_imgs[n].data_
                _continue:
                    {

                    }

                }
            }

            return true;
        }

    }  // namespace tricycle_manned_process

} // namespace ai_engine_module