Blame view

src/common/road_cnn/road_extractor.cpp 2.71 KB
2ae58093   Hu Chunming   添加road_seg算法
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
  #include "road_extractor.h"
  #include <iostream>
  #include "acl/acl.h"
  #include "model_process.h"
  #include "sy_errorinfo.h"
  #include <time.h>
  #include <sys/time.h>
  #include <algorithm>
  
  using namespace std;
  
  namespace atlas_utils {
  
  /**
   * Loads the offline model and caches its input geometry.
   *
   * @param modelPath  Path of the offline model file to load.
   * @return SY_SUCCESS when every step succeeds; SY_FAILED is passed to ACL_CALL
   *         as the error code for any step that does not report SY_SUCCESS.
   */
  int ROADExtract::Init(const char* modelPath) {
      // Query the run mode of the Ascend AI software stack; later interface calls
      // differ depending on the mode.
      ACL_CALL(aclrtGetRunMode(&runMode_), SY_SUCCESS, SY_FAILED);

      // Load the offline model data from file; the model's working memory is managed by the user.
      ACL_CALL(model_.LoadModelFromFileWithMem(modelPath), SY_SUCCESS, SY_FAILED);

      // Fetch the model description.
      ACL_CALL(model_.CreateDesc(), SY_SUCCESS, SY_FAILED);

      // Create the output and collect the output / input dimensions.
      ACL_CALL(model_.CreateOutput(outDims_), SY_SUCCESS, SY_FAILED);
      ACL_CALL(model_.GetInputDims(inDims_), SY_SUCCESS, SY_FAILED);

      // NOTE(review): assumes the first input exists and stores height at index 1
      // and width at index 2 (e.g. an NHWC layout) — confirm against the model.
      const auto& firstInputDims = inDims_[0];
      modelHeight_ = firstInputDims[1];
      modelWidth_ = firstInputDims[2];

      return SY_SUCCESS;
  }
  
  
  /**
   * Runs one model execution on the supplied input buffer.
   *
   * The model input created for this call is always destroyed before returning,
   * including when execution fails, so a failed run does not leak it.
   *
   * @param input  Image whose `data` buffer and `size` are handed to the model as input.
   *               NOTE(review): presumably already preprocessed to the model's input
   *               format — confirm with callers.
   * @return SY_SUCCESS when execution succeeds; SY_FAILED is passed to ACL_CALL
   *         as the error code when Execute() does not report SY_SUCCESS.
   */
  int ROADExtract::Inference(ImageData& input) {
      // NOTE(review): the result of CreateInput is not checked because its return
      // type is not visible in this file — add a check if it can report errors.
      model_.CreateInput(input.data.get(), input.size);

      // Capture the status first: ACL_CALL returns early on failure, which would
      // otherwise skip the mandatory DestroyInput() paired with CreateInput().
      const auto execStatus = model_.Execute();
      model_.DestroyInput();

      ACL_CALL(execStatus, SY_SUCCESS, SY_FAILED);
      return SY_SUCCESS;
  }
  
  int ROADExtract::GetInputWidth() {
      return modelWidth_;
  }
  
  int ROADExtract::GetInputHeight() {
      return modelHeight_;
  }
  
  /**
   * Appends the contents of every model output buffer to `results` as floats.
   *
   * Each output buffer is copied into a temporary host-side staging buffer and
   * then appended to `results` in buffer order. Existing elements of `results`
   * are kept. Only whole floats are copied; trailing bytes that do not form a
   * complete float are ignored, and a buffer holding no complete float is skipped.
   *
   * @param results  Destination vector; output values are appended to it.
   * @return SY_SUCCESS on success. SY_FAILED when the output dataset, a data
   *         buffer or its address is null; SY_FAILED is also the error code
   *         passed to ACL_CALL for a failed memory copy. On failure, values
   *         from buffers processed earlier remain in `results`.
   */
  int ROADExtract::PostProcess(vector<float>& results) {
      aclmdlDataset* modelOutput = model_.GetModelOutputData();
      if (modelOutput == nullptr) {
          return SY_FAILED;
      }

      // The copy direction depends only on the run mode, so choose it once.
      const aclrtMemcpyKind copyKind =
          (runMode_ == ACL_HOST) ? ACL_MEMCPY_DEVICE_TO_HOST : ACL_MEMCPY_DEVICE_TO_DEVICE;

      const size_t bufferCount = aclmdlGetDatasetNumBuffers(modelOutput);
      for (size_t i = 0; i < bufferCount; ++i) {
          aclDataBuffer* dataBuffer = aclmdlGetDatasetBuffer(modelOutput, i);
          if (dataBuffer == nullptr) {
              return SY_FAILED;
          }

          void* data = aclGetDataBufferAddr(dataBuffer);
          if (data == nullptr) {
              return SY_FAILED;
          }

          const size_t floatCount = aclGetDataBufferSize(dataBuffer) / sizeof(float);
          if (floatCount == 0) {
              continue;  // nothing to extract from this buffer
          }

          // Copy whole floats only, so the copy can never run past the staging buffer
          // when the raw buffer size is not a multiple of sizeof(float).
          const size_t copyBytes = floatCount * sizeof(float);

          // Heap-backed staging buffer (outputs may be too large for the stack). The
          // vector frees itself even when ACL_CALL returns early on a failed copy.
          std::vector<float> staging(floatCount);
          ACL_CALL(aclrtMemcpy(staging.data(), copyBytes, data, copyBytes, copyKind), ACL_SUCCESS, SY_FAILED);

          results.insert(results.end(), staging.begin(), staging.end());
      }

      return SY_SUCCESS;
  }
  
  /**
   * Tears down the model held by this extractor: unloads the model, then
   * destroys its description and its output, in that order.
   * NOTE(review): presumably must be paired with a successful Init() — confirm
   * whether calling it without one (or calling it twice) is safe.
   */
  void ROADExtract::Release() {
      model_.Unload();
      model_.DestroyDesc();
      model_.DestroyOutput();
  }
  
  }