Blame view

3rdparty/opencv-4.5.4/modules/dnn/src/cuda4dnn/csl/cudnn/activation.hpp 3.12 KB
f4334277   Hu Chunming   提交3rdparty
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
  // This file is part of OpenCV project.
  // It is subject to the license terms in the LICENSE file found in the top-level directory
  // of this distribution and at http://opencv.org/license.html.
  
  #ifndef OPENCV_DNN_CUDA4DNN_CSL_CUDNN_ACTIVATION_HPP
  #define OPENCV_DNN_CUDA4DNN_CSL_CUDNN_ACTIVATION_HPP
  
  #include <cudnn.h>
  
  namespace cv { namespace dnn { namespace cuda4dnn { namespace csl { namespace cudnn {
  
      class ActivationDescriptor {
      public:
          enum class ActivationType {
              IDENTITY,
              RELU,
              CLIPPED_RELU,
              TANH,
              SIGMOID,
              ELU
          };
  
          ActivationDescriptor() noexcept : descriptor{ nullptr } { }
          ActivationDescriptor(const ActivationDescriptor&) = delete;
          ActivationDescriptor(ActivationDescriptor&& other) noexcept
              : descriptor{ other.descriptor } {
              other.descriptor = nullptr;
          }
  
          /* `relu_ceiling_or_elu_alpha`:
           * - `alpha` coefficient in ELU activation
           * - `ceiling` for CLIPPED_RELU activation
           */
          ActivationDescriptor(ActivationType type, double relu_ceiling_or_elu_alpha = 0.0) {
              CUDA4DNN_CHECK_CUDNN(cudnnCreateActivationDescriptor(&descriptor));
              try {
                  const auto mode = [type] {
                      switch(type) {
                          case ActivationType::IDENTITY: return CUDNN_ACTIVATION_IDENTITY;
                          case ActivationType::RELU: return CUDNN_ACTIVATION_RELU;
                          case ActivationType::CLIPPED_RELU: return CUDNN_ACTIVATION_CLIPPED_RELU;
                          case ActivationType::SIGMOID: return CUDNN_ACTIVATION_SIGMOID;
                          case ActivationType::TANH: return CUDNN_ACTIVATION_TANH;
                          case ActivationType::ELU: return CUDNN_ACTIVATION_ELU;
                      }
                      CV_Assert(0);
                      return CUDNN_ACTIVATION_IDENTITY;
                  } ();
  
                  CUDA4DNN_CHECK_CUDNN(cudnnSetActivationDescriptor(descriptor, mode, CUDNN_NOT_PROPAGATE_NAN, relu_ceiling_or_elu_alpha));
              } catch(...) {
                  /* cudnnDestroyActivationDescriptor will not fail for a valid descriptor object */
                  CUDA4DNN_CHECK_CUDNN(cudnnDestroyActivationDescriptor(descriptor));
                  throw;
              }
          }
  
          ~ActivationDescriptor() noexcept {
              if (descriptor != nullptr) {
                  /* cudnnDestroyActivationDescriptor will not fail */
                  CUDA4DNN_CHECK_CUDNN(cudnnDestroyActivationDescriptor(descriptor));
              }
          }
  
          ActivationDescriptor& operator=(const ActivationDescriptor&) = delete;
          ActivationDescriptor& operator=(ActivationDescriptor&& other) noexcept {
              descriptor = other.descriptor;
              other.descriptor = nullptr;
              return *this;
          };
  
          cudnnActivationDescriptor_t get() const noexcept { return descriptor; }
  
      private:
          cudnnActivationDescriptor_t descriptor;
      };
  
  }}}}} /* namespace cv::dnn::cuda4dnn::csl::cudnn */
  
  #endif /* OPENCV_DNN_CUDA4DNN_CSL_CUDNN_ACTIVATION_HPP */