// cnn_cls.cpp
#include "cnn_cls.h"
#include <iostream>
#include "acl/acl.h"
#include "model_process.h"
#include "sy_errorinfo.h"
#include <time.h>
#include <sys/time.h>
#include <algorithm>

using namespace std;

namespace atlas_utils {

int CnnCls::Init(const char* modelPath) {
    ACL_CALL(aclrtGetRunMode(&runMode_), SY_SUCCESS, SY_FAILED);//获取当前昇腾AI软件栈的运行模式,根据不同的运行模式,后续的接口调用方式不同
    ACL_CALL(model_.LoadModelFromFileWithMem(modelPath), SY_SUCCESS, SY_FAILED);//从文件加载离线模型数据,由用户自行管理模型运行的内存
    ACL_CALL(model_.CreateDesc(), SY_SUCCESS, SY_FAILED);//获取模型的描述信息
    ACL_CALL(model_.CreateOutput(outDims_), SY_SUCCESS, SY_FAILED);
    ACL_CALL(model_.GetInputDims(inDims_), SY_SUCCESS, SY_FAILED);
    modelHeight_ = inDims_[0][1];
    modelWidth_ = inDims_[0][2];

    return SY_SUCCESS;
}

// double msecond1() {
//     struct timeval tv;
//     gettimeofday(&tv, 0);
//     return (tv.tv_sec * 1000.0 + tv.tv_usec / 1000.0);
// }

// Runs one forward pass of the loaded model on `input`.
//
// The model input is created from input.data / input.size, the model is
// executed, and the input is destroyed again.
//
// @param input  image buffer handed to the model; input.data.get() and
//               input.size are passed to CreateInput unchanged.
// @return SY_SUCCESS when Execute() reported SY_SUCCESS; otherwise whatever
//         ACL_CALL yields on failure (SY_FAILED is passed as the failure code).
int CnnCls::Inference(ImageData& input) {
    // NOTE(review): CreateInput's result is not checked, as before; its return
    // type is not visible from this file -- consider checking it.
    model_.CreateInput(input.data.get(), input.size);

    // Capture the result first so the input is destroyed on the failure path
    // too. ACL_CALL presumably returns early on mismatch, which previously
    // skipped DestroyInput() whenever Execute() failed.
    const auto execRet = model_.Execute();

    // Every CreateInput must be paired with DestroyInput.
    model_.DestroyInput();

    ACL_CALL(execRet, SY_SUCCESS, SY_FAILED);
    return SY_SUCCESS;
}

int CnnCls::GetInputWidth() {
    return modelWidth_;
}

int CnnCls::GetInputHeight() {
    return modelHeight_;
}

int CnnCls::PostProcess(vector<float>& results) {
    aclmdlDataset* modelOutput = model_.GetModelOutputData();
    int outDatasetNum = aclmdlGetDatasetNumBuffers(modelOutput);
    for (int i = 0; i < outDatasetNum; i++) {
        aclDataBuffer* dataBuffer = aclmdlGetDatasetBuffer(modelOutput, i);
        if (dataBuffer == nullptr) {
            return SY_FAILED;
        }
        uint32_t dataBufferSize = aclGetDataBufferSize(dataBuffer);
        void* data = aclGetDataBufferAddr(dataBuffer);
        if (data == nullptr) {
            return SY_FAILED;
        }
        
        int length = dataBufferSize/sizeof(float);
        float outInfo[length];
        
        if (runMode_ == ACL_HOST) {
            ACL_CALL(aclrtMemcpy(outInfo, sizeof(outInfo), data, sizeof(outInfo), ACL_MEMCPY_DEVICE_TO_HOST),
                     ACL_SUCCESS, SY_FAILED);
            } else {
                     ACL_CALL(aclrtMemcpy(outInfo, sizeof(outInfo), data, sizeof(outInfo), ACL_MEMCPY_DEVICE_TO_DEVICE),ACL_SUCCESS, SY_FAILED);
               //return SY_FAILED;
            }

        int argmax = std::distance(outInfo, std::max_element(outInfo, outInfo + length)); 
        if(outInfo[argmax] < config.confThr){
            // printf("outInfo[argmax]:%f\n",outInfo[argmax]);
            // printf("config.confThr:%f\n",config.confThr); 
            results.emplace_back(-1);
            results.emplace_back(0);
            // INFO_LOG("vColor is low confidence!");
        }
        else{
            results.emplace_back(argmax);
            results.emplace_back(outInfo[argmax]);
        }
        /*
        for(uint32_t b = 0; b < length; b++) {
           
           results.emplace_back(outInfo[b]);
        }*/
    }

    
    return SY_SUCCESS;
}

int CnnCls::PostProcess_batch(vector<vector<float>>& results) {
    aclmdlDataset* modelOutput = model_.GetModelOutputData();
    int outDatasetNum = aclmdlGetDatasetNumBuffers(modelOutput);
    const int batchsize       = outDims_[0][0];

    // 结果拷贝
    vector<vector<float>> all_res;
    for (int i = 0; i < outDatasetNum; i++) {
        aclDataBuffer* dataBuffer = aclmdlGetDatasetBuffer(modelOutput, i);
        if (dataBuffer == nullptr) {
            return SY_FAILED;
        }
        uint32_t dataBufferSize = aclGetDataBufferSize(dataBuffer);
        void* data = aclGetDataBufferAddr(dataBuffer);
        if (data == nullptr) {
            return SY_FAILED;
        }
        
        int length = dataBufferSize/sizeof(float);
        float outInfo[length];
            
        if (runMode_ == ACL_HOST) {
            ACL_CALL(aclrtMemcpy(outInfo, sizeof(outInfo), data, sizeof(outInfo), ACL_MEMCPY_DEVICE_TO_HOST),ACL_SUCCESS, SY_FAILED);
            } else {
                ACL_CALL(aclrtMemcpy(outInfo, sizeof(outInfo), data, sizeof(outInfo), ACL_MEMCPY_DEVICE_TO_DEVICE),ACL_SUCCESS, SY_FAILED);
            }

        vector<float> res;
        res.assign(outInfo, outInfo + length);
        all_res.emplace_back(res);
    } 

    // 按[b][cls]的形式返回
    for (int b = 0; b < batchsize; b ++) {
        vector<float> result;
        for (const auto& outInfo : all_res) {
            int single_length = outInfo.size() / batchsize;
            // printf("batchsize:%d,single_length:%d\n",batchsize, single_length);
           
            int argmax = std::distance(outInfo.begin()+b*single_length, std::max_element(outInfo.begin()+b*single_length, outInfo.begin() + (b+1)*single_length)); 
            float score = outInfo[b*single_length+argmax];
            if(score < config.confThr) {
                // printf("score:%f, confThr\n",score, config.confThr); 
                result.emplace_back(-1);
                result.emplace_back(0);
            }
            else {
                result.emplace_back(argmax);
                result.emplace_back(score); 
            }    
        }
        results.emplace_back(result);
    }
     
    return SY_SUCCESS;
}

// Tears down the model state held by model_ by calling Unload(),
// DestroyDesc() and DestroyOutput(), in that order. These are presumably the
// counterparts of the LoadModelFromFileWithMem() / CreateDesc() /
// CreateOutput() calls made in Init() -- confirm in model_process.h.
// NOTE(review): any results of these calls are ignored here.
void CnnCls::Release() {
    model_.Unload();
    model_.DestroyDesc();
    model_.DestroyOutput();
}

}