f4334277
Hu Chunming
提交3rdparty
|
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
|
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.
#ifndef OPENCV_DNN_CUDA4DNN_CSL_CUDNN_ACTIVATION_HPP
#define OPENCV_DNN_CUDA4DNN_CSL_CUDNN_ACTIVATION_HPP
#include <cudnn.h>
namespace cv { namespace dnn { namespace cuda4dnn { namespace csl { namespace cudnn {
/** @brief Owning, move-only wrapper around a cuDNN activation descriptor handle.
 *
 * Copying is disabled. Moving transfers ownership of the `cudnnActivationDescriptor_t`
 * handle and leaves the moved-from object empty (holding `nullptr`). A default-constructed
 * object is empty as well, and its `get()` returns `nullptr`. A non-null handle is destroyed
 * when the owning object is destroyed or move-assigned over.
 */
class ActivationDescriptor {
public:
    /** Activation functions this wrapper can configure; each maps to a cuDNN activation mode. */
    enum class ActivationType {
        IDENTITY,
        RELU,
        CLIPPED_RELU,
        TANH,
        SIGMOID,
        ELU
    };

    /** Creates an empty wrapper that owns no descriptor. */
    ActivationDescriptor() noexcept : descriptor{ nullptr } { }
    ActivationDescriptor(const ActivationDescriptor&) = delete;

    /** Takes ownership of `other`'s descriptor; `other` is left empty. */
    ActivationDescriptor(ActivationDescriptor&& other) noexcept
        : descriptor{ other.descriptor } {
        other.descriptor = nullptr;
    }

    /** Creates a cuDNN activation descriptor and configures it for `type`.
     *
     * @param type                       activation function to configure
     * @param relu_ceiling_or_elu_alpha  coefficient passed to cudnnSetActivationDescriptor:
     *                                   - `alpha` coefficient in ELU activation
     *                                   - `ceiling` for CLIPPED_RELU activation
     *                                   (presumably unused by the other modes; verify in cuDNN docs)
     *
     * NaN propagation is disabled (CUDNN_NOT_PROPAGATE_NAN). If anything in the configuration
     * step throws, the freshly created descriptor is destroyed before the exception is rethrown,
     * so a failed construction does not leak the handle.
     */
    ActivationDescriptor(ActivationType type, double relu_ceiling_or_elu_alpha = 0.0) {
        CUDA4DNN_CHECK_CUDNN(cudnnCreateActivationDescriptor(&descriptor));
        try {
            /* translate the wrapper's enum into the corresponding cuDNN activation mode */
            const auto mode = [type] {
                switch(type) {
                case ActivationType::IDENTITY: return CUDNN_ACTIVATION_IDENTITY;
                case ActivationType::RELU: return CUDNN_ACTIVATION_RELU;
                case ActivationType::CLIPPED_RELU: return CUDNN_ACTIVATION_CLIPPED_RELU;
                case ActivationType::SIGMOID: return CUDNN_ACTIVATION_SIGMOID;
                case ActivationType::TANH: return CUDNN_ACTIVATION_TANH;
                case ActivationType::ELU: return CUDNN_ACTIVATION_ELU;
                }
                /* unreachable for valid enumerators; guards against out-of-range values */
                CV_Assert(0);
                return CUDNN_ACTIVATION_IDENTITY;
            } ();

            CUDA4DNN_CHECK_CUDNN(cudnnSetActivationDescriptor(descriptor, mode, CUDNN_NOT_PROPAGATE_NAN, relu_ceiling_or_elu_alpha));
        } catch(...) {
            /* the destructor does not run for a partially constructed object, so release here */
            /* cudnnDestroyActivationDescriptor will not fail for a valid descriptor object */
            CUDA4DNN_CHECK_CUDNN(cudnnDestroyActivationDescriptor(descriptor));
            throw;
        }
    }

    /** Destroys the owned descriptor, if any. */
    ~ActivationDescriptor() noexcept {
        if (descriptor != nullptr) {
            /* cudnnDestroyActivationDescriptor will not fail */
            CUDA4DNN_CHECK_CUDNN(cudnnDestroyActivationDescriptor(descriptor));
        }
    }

    ActivationDescriptor& operator=(const ActivationDescriptor&) = delete;

    /** Transfers ownership of `other`'s descriptor to this object; `other` is left empty.
     *
     * Any descriptor currently owned by this object is destroyed first so it is not leaked.
     * Self-move-assignment is a no-op (the handle is kept rather than nulled out).
     */
    ActivationDescriptor& operator=(ActivationDescriptor&& other) noexcept {
        if (this != &other) {
            if (descriptor != nullptr) {
                /* cudnnDestroyActivationDescriptor will not fail */
                CUDA4DNN_CHECK_CUDNN(cudnnDestroyActivationDescriptor(descriptor));
            }
            descriptor = other.descriptor;
            other.descriptor = nullptr;
        }
        return *this;
    }

    /** Returns the raw cuDNN handle (nullptr for an empty wrapper); ownership is retained. */
    cudnnActivationDescriptor_t get() const noexcept { return descriptor; }

private:
    cudnnActivationDescriptor_t descriptor;
};
}}}}} /* namespace cv::dnn::cuda4dnn::csl::cudnn */
#endif /* OPENCV_DNN_CUDA4DNN_CSL_CUDNN_ACTIVATION_HPP */
|