Blame view

src/common/cnn/cnn_extractor.cpp 2.95 KB
f171c20a   Hu Chunming   添加moter_rainshed ...
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
  #include "cnn_extractor.h"
  
  #include <sys/time.h>
  #include <time.h>
  
  #include <algorithm>
  #include <iostream>
  #include <vector>
  
  #include "acl/acl.h"
  #include "model_process.h"
  #include "sy_errorinfo.h"
  
  using namespace std;
  
  namespace atlas_utils {
  
  int CNNExtract::Init(const char* modelPath) {
      ACL_CALL(aclrtGetRunMode(&runMode_), SY_SUCCESS, SY_FAILED);//获取当前昇腾AI软件栈的运行模式,根据不同的运行模式,后续的接口调用方式不同
      ACL_CALL(model_.LoadModelFromFileWithMem(modelPath), SY_SUCCESS, SY_FAILED);//从文件加载离线模型数据,由用户自行管理模型运行的内存
      ACL_CALL(model_.CreateDesc(), SY_SUCCESS, SY_FAILED);//获取模型的描述信息
      ACL_CALL(model_.CreateOutput(outDims_), SY_SUCCESS, SY_FAILED);
      ACL_CALL(model_.GetInputDims(inDims_), SY_SUCCESS, SY_FAILED);
      modelHeight_ = inDims_[0][1];
      modelWidth_ = inDims_[0][2];
  
      return SY_SUCCESS;
  }
  
  
  int CNNExtract::Inference(ImageData& input) {
      model_.CreateInput(input.data.get(), input.size);
      ACL_CALL(model_.Execute(), SY_SUCCESS, SY_FAILED);
      model_.DestroyInput(); //需调用CreateInput的销毁类接口DestroyInput!!!
      return SY_SUCCESS;
  }
  
  int CNNExtract::GetInputWidth() {
      return modelWidth_;
  }
  
  int CNNExtract::GetInputHeight() {
      return modelHeight_;
  }
  
  int CNNExtract::PostProcess(vector<float>& results) {
      aclmdlDataset* modelOutput = model_.GetModelOutputData();
      int outDatasetNum = aclmdlGetDatasetNumBuffers(modelOutput);
      for (int i = 0; i < outDatasetNum; i++) {
          aclDataBuffer* dataBuffer = aclmdlGetDatasetBuffer(modelOutput, i);
          if (dataBuffer == nullptr) {
              return SY_FAILED;
          }
          uint32_t dataBufferSize = aclGetDataBufferSize(dataBuffer);
          void* data = aclGetDataBufferAddr(dataBuffer);
          if (data == nullptr) {
              return SY_FAILED;
          }
          
          int length = dataBufferSize/sizeof(float); 
          float outInfo[length];
          
          if (runMode_ == ACL_HOST) {
              ACL_CALL(aclrtMemcpy(outInfo, sizeof(outInfo), data, sizeof(outInfo), ACL_MEMCPY_DEVICE_TO_HOST),
                       ACL_SUCCESS, SY_FAILED);
              } else {
                       ACL_CALL(aclrtMemcpy(outInfo, sizeof(outInfo), data, sizeof(outInfo), ACL_MEMCPY_DEVICE_TO_DEVICE),ACL_SUCCESS, SY_FAILED);
                 //return SY_FAILED;
              }
  
              // //归一化
              // float sum = 0.0;
              // for(int j=0;j<length;j++){
              //     sum = sum + outInfo[j]*outInfo[j];
              // }
              // sum = sqrt(sum) + 1e-6;
              // for(int j=0;j<length;j++){
              //     results.emplace_back(outInfo[j]/sum);
              // }
  
               
              for(uint32_t b = 0; b < length; b++) {
                  results.emplace_back(outInfo[b]);
              }
        
      }
  
      
      return SY_SUCCESS;
  }
  
  void CNNExtract::Release() {
      model_.Unload();
      model_.DestroyDesc();
      model_.DestroyOutput();
  }
  
  }