// takeaway_member_cls.cpp (12.3 KB)
/*
 * @Author: yangzilong
 * @Date: 2021-11-25 10:42:54
 * @Last Modified by: yangzilong
 * @Email: yangzilong@objecteye.com
 * @Description: Takeaway-member classification of motorcycle detections (GPU crop + TensorRT batch inference).
 */
#include <algorithm>
#include "takeaway_member_cls.hpp"
#include "../reprocessing_module/CropImg.h"


namespace ai_engine_module
{
    namespace takeaway_member_classification
    {
        algorithm_type_t TakeawayMemberCls::algor_type_ = algorithm_type_t::TAKEAWAY_MEMBER_CLASSIFICATION;


        using namespace classification;

        /**
         * Construct an uninitialised classifier; init() must succeed before update_mstreams() is used.
         */
        TakeawayMemberCls::TakeawayMemberCls()
            : task_param_manager_(nullptr)
        {
            // Give the SDK handle and the init flag defined values. The destructor hands tools_
            // to takeway_member::release() and check_initied() reads init_, so neither may be
            // left indeterminate when init() is never called. (Assigned in the body rather than
            // the init-list to avoid depending on the member declaration order in the header.)
            tools_ = nullptr;
            init_ = false;
        }

        /**
         * Release the classification SDK handle.
         *
         * The handle is created by takeway_member::init() and is returned through
         * takeway_member::release(). The previous `if (!tools_) delete tools_;` had an inverted
         * condition: it only ever "deleted" a null pointer. Had it fired on a live handle, it would
         * have been a mismatched deallocation of SDK-owned memory. It is therefore removed.
         *
         * NOTE(review): results still stored in id_to_result_ keep their device buffers here;
         * ownership of those buffers after get_result_by_objectid() is unclear — confirm.
         */
        TakeawayMemberCls::~TakeawayMemberCls()
        {
            takeway_member::release(&tools_);
            tools_ = nullptr;
        }

        /**
         * Create the takeaway-member classification SDK handle (TensorRT engine).
         *
         * @param gpu_id             GPU index; a negative value selects DEVICE_CPU.
         * @param trt_serialize_file path handed to the SDK as the TensorRT serialized-engine file.
         * @return true when takeway_member::init() returned 0 (the result is also cached in init_).
         */
        bool TakeawayMemberCls::init(int gpu_id, char* trt_serialize_file)
        {
            init_ = false;

            common::init_params_t param;
            param.mode = (gpu_id < 0) ? DEVICE_CPU : DEVICE_GPU;
            param.gpuid = gpu_id;
            param.threshold = 0.0;
            param.engine = ENGINE_TENSORRT;
            param.max_batch = MAX_BATCH;
            param.trt_serialize_file = trt_serialize_file;
            // helpers::os::mkdirp(trt_serialize_file);

            const int status = takeway_member::init(&tools_, &param);
            init_ = (0 == status);
            if (!init_)
            {
                LOG_ERROR("Init TakeawayMemberClsSdk failed error code is {}", status);
                return init_;
            }

            // The task-parameter singleton is only looked up once the SDK is usable.
            if (task_param_manager_ == nullptr)
                task_param_manager_ = task_param_manager::getInstance();
            return init_;
        }


        /**
         * Report whether init() has succeeded, logging an error when it has not.
         *
         * The log call uses `{}` placeholders like every other LOG_* call in this file. The
         * former printf-style "%s:%d" was never substituted by the brace-style formatter.
         *
         * @return the cached init_ flag.
         */
        bool TakeawayMemberCls::check_initied()
        {
            if (!init_)
                LOG_ERROR("[{}:{}] call init function please.", __FILE__, __LINE__);
            return init_;
        }


        /**
         * Drop every stored result that belongs to `task_id`.
         *
         * For each matching entry of id_to_result_, the device buffers behind roi_img and
         * ori_img are released with cudaFree (when non-null) before the entry is erased.
         * Entries of other tasks are left untouched.
         * NOTE(review): the per-object frame counters in id_to_mn_ are not cleared here.
         */
        void TakeawayMemberCls::force_release_result(const task_id_t& task_id) {
            auto cursor = id_to_result_.begin();
            while (cursor != id_to_result_.end()) {
                if (cursor->first.task_id != task_id) {
                    ++cursor;
                    continue;
                }

                auto& stored = cursor->second;
                if (stored.roi_img.data_) {
                    CHECK(cudaFree(stored.roi_img.data_));
                    stored.roi_img.data_ = nullptr;
                }
                if (stored.ori_img.data_) {
                    CHECK(cudaFree(stored.ori_img.data_));
                    stored.ori_img.data_ = nullptr;
                }
                cursor = id_to_result_.erase(cursor);
            }
        }


        std::shared_ptr<result_data_t> TakeawayMemberCls::get_result_by_objectid(const id_t& id, bool do_erase)
        {
            auto it = id_to_result_.find(id);
            if (it == id_to_result_.end())
                return std::shared_ptr<result_data_t>(nullptr);
            std::shared_ptr<result_data_t> res = std::make_shared<result_data_t>(it->second);
            if (do_erase)
                id_to_result_.erase(id);
            return res;
        }

        bool TakeawayMemberCls::update_mstreams(const std::set<task_id_t>& taskIds, const sy_img* det_input_images, const std::vector<onelevel_det_result>& det_results)
        {
            if (!check_initied())
                return false;

            if (det_results.empty())
            {
                LOG_DEBUG("detection result is empty.");
                return false;
            }

            int n_images = det_results.size();  // or n_stream

            unsigned flattened_idx = 0;
            std::map<int, int> flattened_idx_to_batch_idx;

            /* 1. Crop & keep some interest class. */
            auto taskId_iter = taskIds.begin();
            std::vector<sy_img> flattened_imgs(0);
            std::vector<input_data_wrap_t> flattened_interest_data(0);  //
            for (int n = 0; n < n_images; ++n)
            {
                int n_interest_obj = 0;
                auto& src_img = det_input_images[n];
                auto& boxes_of_one_image = det_results[n].obj;
                for (int i = 0; i < det_results[n].obj_count; ++i)
                {
                    auto& box = boxes_of_one_image[i];
                    if (static_cast<det_class_label_t>(box.index) == det_class_label_t::MOTOCYCLE)
                    {
                        auto& taskId = *taskId_iter;
                        auto algor_param_wrap = task_param_manager_->get_task_other_param(taskId, this->algor_type_);
                        if (!algor_param_wrap)
                        {
                            LOG_ERROR("{} is nullptr when get algor param from task_param", taskId.c_str());
                            continue;
                        }
                        auto algor_param = ((algor_param_type)algor_param_wrap->algor_param);

                        input_data_wrap_t data;
                        int top = std::max(int(box.top - (IMAGE_CROP_EXPAND_RATIO * box.top)), 0);
                        int left = std::max(int(box.left - (IMAGE_CROP_EXPAND_RATIO * box.left)), 0);
                        int right = std::min(int(box.right + (IMAGE_CROP_EXPAND_RATIO * box.right)), src_img.w_);
                        int bottom = std::min(int(box.bottom + (IMAGE_CROP_EXPAND_RATIO * box.bottom)), src_img.h_);

                        int width = right - left;
                        int height = bottom - top;

                        if ((width < algor_param->pedestrian_min_width || height < algor_param->pedestrian_min_height || box.confidence < algor_param->pedestrian_confidence_threshold) ||
                            !snapshot_legal_inarea(algor_param_wrap->basic_param->algor_valid_rect, left, top, right, bottom))
                            continue;

                        data.box.top = top;
                        data.box.left = left;
                        data.box.right = right;
                        data.box.bottom = bottom;
                        data.taskId = taskId;
                        data.objId = box.id;
                        data.id = obj_key_t{ box.id, taskId, algorithm_type_t::TAKEAWAY_MEMBER_CLASSIFICATION };

                        sy_img img;
                        {
                            img.w_ = width;
                            img.h_ = height;
                            img.c_ = src_img.c_;
                        }

                        cudaError_t cuda_status;
                        const unsigned nbytes = img.c_ * img.h_ * img.w_ * sizeof(unsigned char);
                        if (CUDA_SUCCESS != (cuda_status = cudaMalloc((void**)&img.data_, nbytes)))
                        {
                            LOG_ERROR("cudaMalloc failed: {} malloc nbytes is {} mb is {} ", cudaGetErrorString(cuda_status), nbytes, nbytes / (1024 * 1024));
                            continue;
                        }

                        if (CUDA_SUCCESS != (cuda_status = cudacommon::CropImgGpu(src_img.data_, src_img.w_, src_img.h_, img.data_, left, top, width, height)))
                        {
                            LOG_ERROR("Crop image GPU failed error is {} wh is [{}, {}] ltrb is [{} {} {} {}]",
                                cudaGetErrorString(cuda_status), src_img.w_, src_img.h_, data.box.left, data.box.top, data.box.right, data.box.bottom);
                            CHECK(cudaFree(img.data_));
                            continue;
                        }
                        flattened_imgs.emplace_back(std::move(img));
                        flattened_interest_data.emplace_back(std::move(data));
                        flattened_idx_to_batch_idx[flattened_idx++] = n;
                    }
                }
                ++taskId_iter;
            }


            /* 2. collection result. */
            int n_input_image = flattened_imgs.size();
            takeway_member::results_t model_results[n_input_image];
            {
                int steps = (n_input_image + MAX_BATCH - 1) / MAX_BATCH;

                for (int step = 0; step < steps; ++step)
                {
                    int offset = step * MAX_BATCH;
                    int batch_size = (step == steps - 1) ? n_input_image - offset : MAX_BATCH;
                    takeway_member::process_batch(tools_, flattened_imgs.data() + offset, batch_size, model_results + offset);
                }
            }


            /* 3. postprocess. */
            {
                for (int n = 0; n < n_input_image; ++n)
                {
                    auto& det_result = flattened_interest_data[n];
                    auto& objId = det_result.objId;
                    if (id_to_result_.find(det_result.id) != id_to_result_.end())
                    {
                        CHECK(cudaFree(flattened_imgs[n].data_));
                        flattened_imgs[n].data_ = nullptr;
                        continue;
                    }

                    const sy_img& src_img = det_input_images[flattened_idx_to_batch_idx[n]];

                    auto algor_param_wrap = task_param_manager_->get_task_other_param(det_result.taskId, this->algor_type_);
                    if (!algor_param_wrap)
                    {
                        LOG_ERROR("{} nullptr when get algor param from task_param", det_result.taskId.c_str());
                        CHECK(cudaFree(flattened_imgs[n].data_));
                        flattened_imgs[n].data_ = nullptr;
                        continue;
                    }

                    auto algor_param = ((algor_param_type)algor_param_wrap->algor_param);
                    takeway_member::takeaway_member_label_t takeway_member_cls = static_cast<takeway_member::takeaway_member_label_t>(
                        model_results[n].multi_label_cls_result[(int)takeway_member::label_index_t::TAKEAWAY_MEMBER].category);
                    det_result.box.score = model_results[n].multi_label_cls_result[(int)takeway_member::label_index_t::TAKEAWAY_MEMBER].prob;

                    obj_key_t obj_key{ det_result.objId, det_result.taskId, algorithm_type_t::TAKEAWAY_MEMBER_CLASSIFICATION };

                    auto& e = id_to_mn_[obj_key];
                    ++e.m_frame;

                    if (takeway_member_cls != takeway_member::takeaway_member_label_t::NOT &&
                        det_result.box.score >= algor_param->threshold)
                    {
                        if (++e.n_frame == algor_param->n)
                        {
                            result_data_t result;
                            {
                                result.box = det_result.box;
                                result.taskId = det_result.taskId;
                                result.objId = det_result.objId;
#if 0
                                {
                                    result.ori_img = src_img;
                                }
#else
                                {
                                    sy_img img;
                                    {
                                        img.c_ = src_img.c_;
                                        img.h_ = src_img.h_;
                                        img.w_ = src_img.w_;
                                    }
                                    unsigned nbytes = img.c_ * img.h_ * img.w_ * sizeof(unsigned char);
                                    CHECK(cudaMalloc(&img.data_, nbytes));
                                    CHECK(cudaMemcpy(img.data_, src_img.data_, nbytes, cudaMemcpyDeviceToDevice));
                                    result.ori_img = std::move(img);
                                }
#endif
                                result.roi_img = std::move(flattened_imgs[n]);
                                result.category = (int)takeway_member_cls;
                            }
                            id_to_result_.emplace(obj_key, std::move(result));
                            goto _continue;
                        }
                    }

                    if (e.m_frame == algor_param->m)
                        e.reset();

                    CHECK(cudaFree(flattened_imgs[n].data_));
                _continue:
                    {

                    }

                }
            }


            return true;

        }

    }  // namespace takeaway_member_classification

} // namespace ai_engine_module