#include "cnn_cls.h"

// NOTE(review): the names of this file's system headers were not readable in
// the reviewed copy. The list below is exactly what the code in this file
// uses; re-add any other system headers the project build expects.
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <iterator>
#include <utility>
#include <vector>

#include "acl/acl.h"
#include "model_process.h"
#include "sy_errorinfo.h"

using namespace std;

namespace atlas_utils {

namespace {

/// Copies output buffer `index` of `dataset` into `out` as an array of floats.
///
/// @param dataset  Model output dataset to read from.
/// @param index    Index of the buffer inside the dataset.
/// @param runMode  Run mode reported by aclrtGetRunMode(); selects the memcpy kind.
/// @param out      Receives (buffer size in bytes) / sizeof(float) values.
/// @return SY_SUCCESS on success; SY_FAILED if the buffer or its address is
///         null, if the buffer holds no complete float, or if the copy fails.
int CopyOutputToFloats(const aclmdlDataset* dataset, size_t index,
                       aclrtRunMode runMode, std::vector<float>& out) {
    aclDataBuffer* buffer = aclmdlGetDatasetBuffer(dataset, index);
    if (buffer == nullptr) {
        return SY_FAILED;
    }

    const uint32_t byteSize = aclGetDataBufferSize(buffer);
    void* src = aclGetDataBufferAddr(buffer);
    if (src == nullptr) {
        return SY_FAILED;
    }

    const size_t count = byteSize / sizeof(float);
    if (count == 0) {
        // Nothing to classify; callers would otherwise index an empty range.
        return SY_FAILED;
    }

    // Heap storage instead of a variable-length array: VLAs are not standard
    // C++ and the size comes from the model, so it is unbounded for the stack.
    out.resize(count);
    const size_t copyBytes = count * sizeof(float);
    const aclrtMemcpyKind kind = (runMode == ACL_HOST)
                                     ? ACL_MEMCPY_DEVICE_TO_HOST
                                     : ACL_MEMCPY_DEVICE_TO_DEVICE;
    ACL_CALL(aclrtMemcpy(out.data(), copyBytes, src, copyBytes, kind),
             ACL_SUCCESS, SY_FAILED);
    return SY_SUCCESS;
}

/// Appends the arg-max result of the non-empty range [first, last) to `out`.
///
/// Two values are appended: {class index, score} when the best score is not
/// below `confThr`, otherwise {-1, 0} to mark a low-confidence result.
void AppendBestClass(const float* first, const float* last, double confThr,
                     std::vector<float>& out) {
    const float* best = std::max_element(first, last);
    const float score = *best;
    if (score < confThr) {
        out.emplace_back(-1.0f);
        out.emplace_back(0.0f);
    } else {
        out.emplace_back(static_cast<float>(std::distance(first, best)));
        out.emplace_back(score);
    }
}

}  // namespace

/// Loads the model at `modelPath` and caches its input height and width.
///
/// @param modelPath  Path of the offline model file; must not be null.
/// @return SY_SUCCESS on success, SY_FAILED if any step fails or the first
///         input has fewer than three dimensions.
int CnnCls::Init(const char* modelPath) {
    if (modelPath == nullptr) {
        return SY_FAILED;
    }

    // Query the run mode of the Ascend AI software stack; later API calls
    // differ depending on it (see the memcpy kind used in post-processing).
    ACL_CALL(aclrtGetRunMode(&runMode_), ACL_SUCCESS, SY_FAILED);
    // Load the offline model from file; the memory used to run the model is
    // managed by the user.
    ACL_CALL(model_.LoadModelFromFileWithMem(modelPath), SY_SUCCESS, SY_FAILED);
    // Fetch the model description.
    ACL_CALL(model_.CreateDesc(), SY_SUCCESS, SY_FAILED);
    ACL_CALL(model_.CreateOutput(outDims_), SY_SUCCESS, SY_FAILED);
    ACL_CALL(model_.GetInputDims(inDims_), SY_SUCCESS, SY_FAILED);

    // Indices 1 and 2 of the first input are read below; make sure they exist.
    if (inDims_.empty() || inDims_[0].size() < 3) {
        return SY_FAILED;
    }
    modelHeight_ = inDims_[0][1];
    modelWidth_ = inDims_[0][2];
    return SY_SUCCESS;
}

// Timing helper kept for debugging (needs <sys/time.h>):
// double msecond1() {
//     struct timeval tv;
//     gettimeofday(&tv, 0);
//     return (tv.tv_sec * 1000.0 + tv.tv_usec / 1000.0);
// }

/// Runs the model once on `input`.
///
/// @param input  Image whose `data` buffer and `size` are bound as model input.
/// @return SY_SUCCESS if execution succeeded, SY_FAILED otherwise.
int CnnCls::Inference(ImageData& input) {
    // NOTE(review): the result of CreateInput is not checked because its
    // return type is not visible from this file - confirm and handle it.
    model_.CreateInput(input.data.get(), input.size);

    // Execute inside a lambda so that an early return from ACL_CALL cannot
    // skip DestroyInput below (previously the input was left alive on failure).
    const int status = [this]() -> int {
        ACL_CALL(model_.Execute(), SY_SUCCESS, SY_FAILED);
        return SY_SUCCESS;
    }();

    // Every CreateInput must be paired with DestroyInput.
    model_.DestroyInput();
    return status;
}

/// Returns the model input width cached by Init().
int CnnCls::GetInputWidth() {
    return modelWidth_;
}

/// Returns the model input height cached by Init().
int CnnCls::GetInputHeight() {
    return modelHeight_;
}

/// Appends one {class index, score} pair per model output to `results`.
///
/// For each output buffer the highest-scoring element is selected; if its
/// score is below `config.confThr` the pair {-1, 0} is appended instead.
///
/// @param results  Receives two values per output buffer.
/// @return SY_SUCCESS on success, SY_FAILED if the output dataset or any
///         buffer is missing/empty or a copy fails.
int CnnCls::PostProcess(vector<float>& results) {
    aclmdlDataset* modelOutput = model_.GetModelOutputData();
    if (modelOutput == nullptr) {
        return SY_FAILED;
    }

    const size_t outputCount = aclmdlGetDatasetNumBuffers(modelOutput);
    std::vector<float> scores;  // reused across outputs
    for (size_t i = 0; i < outputCount; ++i) {
        if (CopyOutputToFloats(modelOutput, i, runMode_, scores) != SY_SUCCESS) {
            return SY_FAILED;
        }
        AppendBestClass(scores.data(), scores.data() + scores.size(),
                        config.confThr, results);
    }
    return SY_SUCCESS;
}

/// Batched variant of PostProcess(): appends one row per batch item.
///
/// The batch size is taken from `outDims_[0][0]`. Each output buffer is split
/// into `batch size` equal slices; row `b` holds, for every output, the
/// {class index, score} pair of slice `b` (or {-1, 0} below `config.confThr`).
///
/// @param results  Receives `batch size` rows, laid out as [batch][output pair].
/// @return SY_SUCCESS on success, SY_FAILED if the batch size is not positive,
///         an output holds fewer floats than the batch size, or a buffer is
///         missing or a copy fails. `results` is untouched on failure.
int CnnCls::PostProcess_batch(vector<vector<float>>& results) {
    aclmdlDataset* modelOutput = model_.GetModelOutputData();
    if (modelOutput == nullptr) {
        return SY_FAILED;
    }
    // The batch size is used as a divisor below, so it must be positive.
    if (outDims_.empty() || outDims_[0].empty() || outDims_[0][0] <= 0) {
        return SY_FAILED;
    }
    const size_t batchSize = static_cast<size_t>(outDims_[0][0]);
    const size_t outputCount = aclmdlGetDatasetNumBuffers(modelOutput);

    // Copy every output first.
    std::vector<std::vector<float>> outputs;
    outputs.reserve(outputCount);
    for (size_t i = 0; i < outputCount; ++i) {
        std::vector<float> scores;
        if (CopyOutputToFloats(modelOutput, i, runMode_, scores) != SY_SUCCESS) {
            return SY_FAILED;
        }
        // Fewer floats than batch items would give an empty per-item slice.
        if (scores.size() < batchSize) {
            return SY_FAILED;
        }
        outputs.emplace_back(std::move(scores));
    }

    // Return the results in [batch][class pair] form.
    results.reserve(results.size() + batchSize);
    for (size_t b = 0; b < batchSize; ++b) {
        std::vector<float> row;
        row.reserve(outputs.size() * 2);
        for (const auto& scores : outputs) {
            const size_t perItem = scores.size() / batchSize;
            const float* first = scores.data() + b * perItem;
            AppendBestClass(first, first + perItem, config.confThr, row);
        }
        results.emplace_back(std::move(row));
    }
    return SY_SUCCESS;
}

/// Unloads the model and destroys its description and output resources.
void CnnCls::Release() {
    model_.Unload();
    model_.DestroyDesc();
    model_.DestroyOutput();
}

}  // namespace atlas_utils