// cnn_extractor.cpp (2.95 KB)
#include "cnn_extractor.h"
#include <iostream>
#include "acl/acl.h"
#include "model_process.h"
#include "sy_errorinfo.h"
#include <time.h>
#include <sys/time.h>
#include <algorithm>

using namespace std;

namespace atlas_utils {

int CNNExtract::Init(const char* modelPath) {
    ACL_CALL(aclrtGetRunMode(&runMode_), SY_SUCCESS, SY_FAILED);//获取当前昇腾AI软件栈的运行模式,根据不同的运行模式,后续的接口调用方式不同
    ACL_CALL(model_.LoadModelFromFileWithMem(modelPath), SY_SUCCESS, SY_FAILED);//从文件加载离线模型数据,由用户自行管理模型运行的内存
    ACL_CALL(model_.CreateDesc(), SY_SUCCESS, SY_FAILED);//获取模型的描述信息
    ACL_CALL(model_.CreateOutput(outDims_), SY_SUCCESS, SY_FAILED);
    ACL_CALL(model_.GetInputDims(inDims_), SY_SUCCESS, SY_FAILED);
    modelHeight_ = inDims_[0][1];
    modelWidth_ = inDims_[0][2];

    return SY_SUCCESS;
}


/**
 * Runs the loaded model once on @p input.
 *
 * The outputs stay in the model's output dataset and are read later by
 * PostProcess(). The input dataset built from @p input is destroyed before
 * returning, also when execution fails, so a failed Execute() no longer
 * leaks it.
 *
 * @param input Preprocessed image; input.data / input.size are passed to CreateInput().
 * @return SY_SUCCESS on success, SY_FAILED if creating the input or executing fails.
 */
int CNNExtract::Inference(ImageData& input) {
    // Previously the result was ignored, so Execute could run without a valid input.
    ACL_CALL(model_.CreateInput(input.data.get(), input.size), SY_SUCCESS, SY_FAILED);
    const auto executeRet = model_.Execute();
    // DestroyInput is the counterpart of CreateInput. It must run on every
    // path after a successful CreateInput, so it is called before checking the
    // Execute result rather than after an early return.
    model_.DestroyInput();
    ACL_CALL(executeRet, SY_SUCCESS, SY_FAILED);
    return SY_SUCCESS;
}

int CNNExtract::GetInputWidth() {
    return modelWidth_;
}

int CNNExtract::GetInputHeight() {
    return modelHeight_;
}

/**
 * Copies every output buffer of the last inference into @p results as floats.
 *
 * Outputs are appended in dataset order, and @p results is NOT cleared first.
 * Each buffer is read as a packed float array; trailing bytes that do not
 * form a whole float are ignored. Empty buffers contribute nothing. Values
 * are returned raw (no normalization).
 *
 * @param results Receives the output values.
 * @return SY_SUCCESS on success; SY_FAILED if a buffer or its address is
 *         null, or a copy fails. On failure, values from outputs processed
 *         before the failing one remain in @p results.
 */
int CNNExtract::PostProcess(vector<float>& results) {
    aclmdlDataset* modelOutput = model_.GetModelOutputData();
    const size_t outDatasetNum = aclmdlGetDatasetNumBuffers(modelOutput);
    for (size_t i = 0; i < outDatasetNum; ++i) {
        aclDataBuffer* dataBuffer = aclmdlGetDatasetBuffer(modelOutput, i);
        if (dataBuffer == nullptr) {
            return SY_FAILED;
        }
        const uint32_t dataBufferSize = aclGetDataBufferSize(dataBuffer);
        void* data = aclGetDataBufferAddr(dataBuffer);
        if (data == nullptr) {
            return SY_FAILED;
        }

        const size_t length = dataBufferSize / sizeof(float);
        if (length == 0) {
            continue;  // nothing to copy; avoids a zero-sized buffer and copy
        }
        // Heap buffer instead of the former VLA: VLAs are not standard C++,
        // and a large model output could overflow the stack.
        std::vector<float> outInfo(length);
        const size_t bytes = length * sizeof(float);

        // The copy direction depends on the run mode queried in Init().
        if (runMode_ == ACL_HOST) {
            ACL_CALL(aclrtMemcpy(outInfo.data(), bytes, data, bytes, ACL_MEMCPY_DEVICE_TO_HOST),
                     ACL_SUCCESS, SY_FAILED);
        } else {
            ACL_CALL(aclrtMemcpy(outInfo.data(), bytes, data, bytes, ACL_MEMCPY_DEVICE_TO_DEVICE),
                     ACL_SUCCESS, SY_FAILED);
        }

        // An L2 normalization step used to live here; it is disabled, so raw values are appended.
        results.insert(results.end(), outInfo.begin(), outInfo.end());
    }

    return SY_SUCCESS;
}

void CNNExtract::Release() {
    model_.Unload();
    model_.DestroyDesc();
    model_.DestroyOutput();
}

}