#include "road_extractor.h"

// NOTE(review): the <...> names of the standard-library includes were lost in
// this copy of the file; this list is reconstructed from what the code uses —
// confirm against version control.
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

#include "acl/acl.h"
#include "model_process.h"
#include "sy_errorinfo.h"

// Kept: macros such as ACL_CALL may rely on unqualified std names.
using namespace std;

namespace atlas_utils {

/**
 * @brief Loads the offline model and caches the model's input height/width.
 *
 * @param modelPath Path of the offline model file.
 * @return SY_SUCCESS on success; SY_FAILED if any ACL/model step fails or the
 *         first input's dims are too short to read height and width from.
 */
int ROADExtract::Init(const char* modelPath)
{
    // Query the run mode of the Ascend software stack; later API usage
    // (e.g. the memcpy kind in PostProcess) depends on it.
    ACL_CALL(aclrtGetRunMode(&runMode_), SY_SUCCESS, SY_FAILED);
    // Load the offline model from file; the model's working memory is
    // managed by the application rather than by ACL.
    ACL_CALL(model_.LoadModelFromFileWithMem(modelPath), SY_SUCCESS, SY_FAILED);
    // Fetch the model description.
    ACL_CALL(model_.CreateDesc(), SY_SUCCESS, SY_FAILED);
    ACL_CALL(model_.CreateOutput(outDims_), SY_SUCCESS, SY_FAILED);
    ACL_CALL(model_.GetInputDims(inDims_), SY_SUCCESS, SY_FAILED);

    // Height/width are read from indices 1 and 2 of the first input's dims
    // (presumably NHWC — TODO confirm against the model). Guard the indexing.
    if (inDims_.empty() || inDims_[0].size() < 3) {
        return SY_FAILED;
    }
    modelHeight_ = inDims_[0][1];
    modelWidth_ = inDims_[0][2];
    return SY_SUCCESS;
}

/**
 * @brief Runs one forward pass of the model on a preprocessed image.
 *
 * The input wrapper created for the model is always destroyed again,
 * including when execution fails.
 *
 * @param input Preprocessed image; input.data / input.size describe the
 *              buffer handed to the model.
 * @return SY_SUCCESS on success; SY_FAILED if creating the input or
 *         executing the model fails.
 */
int ROADExtract::Inference(ImageData& input)
{
    ACL_CALL(model_.CreateInput(input.data.get(), input.size), SY_SUCCESS, SY_FAILED);

    // Capture the result first so DestroyInput (the required counterpart of
    // CreateInput) also runs on the failure path instead of being skipped by
    // ACL_CALL's early return.
    auto executeRet = model_.Execute();
    model_.DestroyInput();
    ACL_CALL(executeRet, SY_SUCCESS, SY_FAILED);
    return SY_SUCCESS;
}

/** @brief Model input width cached by Init(). */
int ROADExtract::GetInputWidth()
{
    return modelWidth_;
}

/** @brief Model input height cached by Init(). */
int ROADExtract::GetInputHeight()
{
    return modelHeight_;
}

/**
 * @brief Copies every output buffer of the last inference into `results`.
 *
 * Each output buffer is interpreted as an array of float and appended in
 * output order; existing contents of `results` are kept. A trailing partial
 * float (buffer size not a multiple of sizeof(float)) is ignored.
 *
 * @param results Receives the output values (appended).
 * @return SY_SUCCESS on success; SY_FAILED if the output dataset or any of
 *         its buffers/addresses is missing, or a memcpy fails.
 */
int ROADExtract::PostProcess(vector<float>& results)
{
    aclmdlDataset* modelOutput = model_.GetModelOutputData();
    if (modelOutput == nullptr) {
        return SY_FAILED;
    }

    // Host-side staging buffer, reused across outputs. Being a vector it is
    // released automatically on every early return (no leak on memcpy failure).
    vector<float> hostBuffer;

    // Per ACL run modes: in host mode the output must cross device->host;
    // in device mode the application itself runs on the device.
    const aclrtMemcpyKind copyKind =
        (runMode_ == ACL_HOST) ? ACL_MEMCPY_DEVICE_TO_HOST : ACL_MEMCPY_DEVICE_TO_DEVICE;

    const size_t outDatasetNum = aclmdlGetDatasetNumBuffers(modelOutput);
    for (size_t i = 0; i < outDatasetNum; ++i) {
        aclDataBuffer* dataBuffer = aclmdlGetDatasetBuffer(modelOutput, i);
        if (dataBuffer == nullptr) {
            return SY_FAILED;
        }
        void* data = aclGetDataBufferAddr(dataBuffer);
        if (data == nullptr) {
            return SY_FAILED;
        }

        const size_t dataBufferSize = aclGetDataBufferSize(dataBuffer);
        const size_t length = dataBufferSize / sizeof(float);
        if (length == 0) {
            continue;  // nothing to copy for this output
        }

        // Copy only whole floats so the copy can never exceed the host buffer
        // (the original passed the raw byte size, overflowing when it was not
        // a multiple of sizeof(float)).
        const size_t copyBytes = length * sizeof(float);
        hostBuffer.resize(length);
        ACL_CALL(aclrtMemcpy(hostBuffer.data(), copyBytes, data, copyBytes, copyKind),
                 ACL_SUCCESS, SY_FAILED);

        results.insert(results.end(), hostBuffer.begin(), hostBuffer.end());
    }
    return SY_SUCCESS;
}

/**
 * @brief Unloads the model and destroys its description and output resources.
 */
void ROADExtract::Release()
{
    model_.Unload();
    model_.DestroyDesc();
    model_.DestroyOutput();
}

}  // namespace atlas_utils