// road_extractor.cpp
#include "road_extractor.h"
#include <iostream>
#include "acl/acl.h"
#include "model_process.h"
#include "sy_errorinfo.h"
#include <time.h>
#include <sys/time.h>
#include <algorithm>

using namespace std;

namespace atlas_utils {

int ROADExtract::Init(const char* modelPath) {
    ACL_CALL(aclrtGetRunMode(&runMode_), SY_SUCCESS, SY_FAILED);//获取当前昇腾AI软件栈的运行模式,根据不同的运行模式,后续的接口调用方式不同
    ACL_CALL(model_.LoadModelFromFileWithMem(modelPath), SY_SUCCESS, SY_FAILED);//从文件加载离线模型数据,由用户自行管理模型运行的内存
    ACL_CALL(model_.CreateDesc(), SY_SUCCESS, SY_FAILED);//获取模型的描述信息
    ACL_CALL(model_.CreateOutput(outDims_), SY_SUCCESS, SY_FAILED);
    ACL_CALL(model_.GetInputDims(inDims_), SY_SUCCESS, SY_FAILED);
    modelHeight_ = inDims_[0][1];
    modelWidth_ = inDims_[0][2];

    return SY_SUCCESS;
}


/**
 * @brief Runs one inference pass on the given input buffer.
 *
 * Wraps input.data / input.size as the model input, executes the model, and
 * destroys the input wrapper again before returning.
 *
 * @param input Image buffer passed to the model as its input.
 * @return SY_SUCCESS if execution succeeded, otherwise SY_FAILED
 *         (ACL_CALL presumably returns its third argument on mismatch).
 */
int ROADExtract::Inference(ImageData& input) {
    // NOTE(review): the result of CreateInput is not checked; its return type
    // is not visible from this file -- consider validating it.
    model_.CreateInput(input.data.get(), input.size);

    // Capture the result first so the input created above is destroyed on
    // both the success and the failure path; checking before DestroyInput()
    // would skip the cleanup whenever Execute() fails.
    const auto execRet = model_.Execute();
    model_.DestroyInput();  // every CreateInput must be paired with DestroyInput

    ACL_CALL(execRet, SY_SUCCESS, SY_FAILED);
    return SY_SUCCESS;
}

int ROADExtract::GetInputWidth() {
    return modelWidth_;
}

int ROADExtract::GetInputHeight() {
    return modelHeight_;
}

/**
 * @brief Appends every model output buffer to @p results as floats.
 *
 * Walks the buffers of the model output dataset in index order and copies
 * each one, interpreted as an array of float, to the end of @p results.
 * Existing contents of @p results are kept.
 *
 * @param results Destination vector; output values are appended to it.
 * @return SY_SUCCESS on success; SY_FAILED if the output dataset, a buffer,
 *         or a buffer address is null, or if a copy fails. When a copy fails,
 *         the elements reserved for that buffer are removed again, so
 *         @p results only holds data from buffers copied successfully.
 */
int ROADExtract::PostProcess(vector<float>& results) {
    aclmdlDataset* modelOutput = model_.GetModelOutputData();
    if (modelOutput == nullptr) {
        return SY_FAILED;
    }

    // The copy direction depends on the run mode queried in Init().
    const aclrtMemcpyKind copyKind = (runMode_ == ACL_HOST)
                                         ? ACL_MEMCPY_DEVICE_TO_HOST
                                         : ACL_MEMCPY_DEVICE_TO_DEVICE;

    const size_t bufferCount = aclmdlGetDatasetNumBuffers(modelOutput);
    for (size_t i = 0; i < bufferCount; ++i) {
        aclDataBuffer* dataBuffer = aclmdlGetDatasetBuffer(modelOutput, i);
        if (dataBuffer == nullptr) {
            return SY_FAILED;
        }

        void* data = aclGetDataBufferAddr(dataBuffer);
        if (data == nullptr) {
            return SY_FAILED;
        }

        // Only whole floats are copied: sizing the copy by the raw byte count
        // would write past the destination when it is not a multiple of
        // sizeof(float).
        const size_t floatCount = aclGetDataBufferSize(dataBuffer) / sizeof(float);
        if (floatCount == 0) {
            continue;  // nothing to append for an empty buffer
        }
        const size_t copyBytes = floatCount * sizeof(float);

        // Grow the result vector and copy straight into it; this avoids a
        // temporary heap buffer that would leak on an early return.
        const size_t oldSize = results.size();
        results.resize(oldSize + floatCount);
        const aclError copyRet =
            aclrtMemcpy(results.data() + oldSize, copyBytes, data, copyBytes, copyKind);
        if (copyRet != ACL_SUCCESS) {
            results.resize(oldSize);  // drop the placeholder elements
        }
        ACL_CALL(copyRet, ACL_SUCCESS, SY_FAILED);
    }

    return SY_SUCCESS;
}

/// @brief Tears down the model wrapper: unloads the model, then destroys its
///        description and its output buffers, in that order.
/// NOTE(review): presumably the counterpart of Init(); whether it is safe to
/// call twice or without a prior successful Init() depends on the model
/// wrapper and is not visible here.
void ROADExtract::Release() {
    model_.Unload();
    model_.DestroyDesc();
    model_.DestroyOutput();
}

}