region.cu
9.58 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include "math.hpp"
#include "grid_stride_range.hpp"
#include "execution.hpp"
#include "limits.hpp"
#include "vector_traits.hpp"
#include "../cuda4dnn/csl/stream.hpp"
#include "../cuda4dnn/csl/span.hpp"
#include <opencv2/core.hpp>
#include <cstddef>
using namespace cv::dnn::cuda4dnn::csl;
using namespace cv::dnn::cuda4dnn::csl::device;
namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
namespace raw {
/* Decodes the five leading values of every box: center x, center y, width, height and objectness.
 *
 * One loop iteration handles one box; there are `output.size() / box_size` boxes in total.
 * The index arithmetic treats boxes as ordered by (batch, row, column, box-within-cell), with the
 * box-within-cell index varying fastest.
 *
 * output, input      buffers holding `box_size` values per box; only offsets 0..4 are touched here
 * bias               pairs of (width, height) priors, one pair per box-within-cell
 * scale_x_y          scale applied to the center offsets around 0.5
 * height_norm,
 * width_norm         divisors applied to the decoded height and width
 * object_prob_cutoff objectness values below this are written out as zero
 * new_coords         when true the inputs are used as they are (no sigmoid/exp is applied) and the
 *                    width/height are decoded as (2 * value)^2 * prior
 */
template <class T>
__global__ void region_box(
    Span<T> output, View<T> input, View<T> bias,
    size_type boxes_per_cell, size_type box_size,
    size_type rows, size_type cols, T scale_x_y,
    size_type height_norm, size_type width_norm,
    T object_prob_cutoff, bool new_coords)
{
    using vector2_type = get_vector_type_t<T, 2>;
    const auto priors = vector2_type::get_pointer(bias.data());

    const T one_half = static_cast<T>(0.5);

    /* number of boxes in one row of cells and in one whole batch item */
    const auto boxes_per_row = cols * boxes_per_cell;
    const auto boxes_per_item = rows * cols * boxes_per_cell;

    for (auto box : grid_stride_range(output.size() / box_size)) {
        const auto base = box * box_size;
        const auto prior_id = box % boxes_per_cell;
        const auto cell_x = (box % boxes_per_row) / boxes_per_cell;
        const auto cell_y = (box % boxes_per_item) / boxes_per_row;

        /* pick the activated center offsets, the size multipliers and the objectness */
        T center_x, center_y, width_factor, height_factor, objectness;
        if (new_coords) {
            /* values are already activated; sizes follow the squared formulation */
            center_x = input[base + 0];
            center_y = input[base + 1];
            width_factor = input[base + 2] * input[base + 2] * static_cast<T>(4);
            height_factor = input[base + 3] * input[base + 3] * static_cast<T>(4);
            objectness = input[base + 4];
        } else {
            center_x = fast_sigmoid(input[base + 0]);
            center_y = fast_sigmoid(input[base + 1]);
            width_factor = fast_exp(input[base + 2]);
            height_factor = fast_exp(input[base + 3]);
            objectness = fast_sigmoid(input[base + 4]);
        }

        vector2_type prior;
        v_load(prior, priors[prior_id]);

        /* rescale the offset around the cell center, then normalize by the grid size */
        const T offset_x = (center_x - one_half) * scale_x_y + one_half;
        const T offset_y = (center_y - one_half) * scale_x_y + one_half;
        output[base + 0] = fast_divide_ftz(static_cast<T>(cell_x) + offset_x, static_cast<T>(cols));
        output[base + 1] = fast_divide_ftz(static_cast<T>(cell_y) + offset_y, static_cast<T>(rows));

        output[base + 2] = width_factor * prior.data[0] / static_cast<T>(width_norm);
        output[base + 3] = height_factor * prior.data[1] / static_cast<T>(height_norm);

        /* predictions whose objectness is under the cutoff are suppressed */
        output[base + 4] = (objectness < object_prob_cutoff) ? static_cast<T>(0) : objectness;
    }
}
/* Converts the per-class values of every box into thresholded class probabilities.
 *
 * One loop iteration handles one element of `output`. Elements at in-box offsets 0..4 are left
 * alone; they are expected to have been filled in by region_box already, and offset 4 (the
 * objectness) is read back from `output` here.
 *
 * class_prob_cutoff  scores less than or equal to this are written out as zero
 * box_size           number of values stored per box
 * new_coords         when true the inputs are used directly instead of being passed through a sigmoid
 */
template <class T>
__global__ void region_sigmoid_class_score(Span<T> output, View<T> input, T class_prob_cutoff,
    size_type box_size, bool new_coords)
{
    for (auto element : grid_stride_range(output.size())) {
        const index_type offset_in_box = element % box_size;

        /* only the class slots (offset 5 and beyond) are handled by this kernel */
        if (offset_in_box >= 5) {
            const index_type objectness_idx = element - offset_in_box + 4;
            const T objectness = output[objectness_idx];

            /* the inputs are class probabilities conditioned on an object being present;
             * multiplying by the objectness gives the unconditional class probability
             */
            const T conditional = new_coords ? input[element] : fast_sigmoid(input[element]);
            const T score = objectness * conditional;

            output[element] = (score <= class_prob_cutoff) ? T(0) : score;
        }
    }
}
/* Applies a softmax over the class values of every box and converts the result into thresholded
 * class probabilities.
 *
 * One loop iteration handles one whole box; there are `output.size() / box_size` boxes. The class
 * values occupy in-box offsets [5, box_size). Offset 4 (the objectness) is read from `output` and
 * is expected to have been written by region_box beforehand; this kernel never modifies it.
 *
 * class_prob_cutoff  scores less than or equal to this are written out as zero
 * box_size           number of values stored per box
 */
template <class T>
__global__ void region_softmax_class_score(Span<T> output, View<T> input, T class_prob_cutoff, size_type box_size) {
    for (auto box_no : grid_stride_range(output.size() / box_size)) {
        const index_type start_of_box = box_no * box_size;
        const index_type start_idx = start_of_box + 5;
        const index_type end_idx = start_of_box + box_size;

        /* the loops below only write to the class slots, so the objectness can be read once */
        const T objectness_prob = output[start_of_box + 4];

        /* pass 1: find the maximum so that the exponentials below cannot overflow */
        auto largest = numeric_limits<T>::lowest();
        for (index_type idx = start_idx; idx < end_idx; idx++) {
            using device::max;
            largest = max(largest, input[idx]);
        }

        /* pass 2: store the shifted exponentials in the output and accumulate their sum */
        auto sum = T(0);
        for (index_type idx = start_idx; idx < end_idx; idx++) {
            using device::exp;
            auto temp = exp(input[idx] - largest);
            sum += temp;
            output[idx] = temp;
        }

        /* pass 3: normalize and scale by the objectness
         *
         * the softmax scores are class probabilities conditioned on an object being present;
         * multiplying by the objectness gives the unconditional class probability
         */
        for (index_type idx = start_idx; idx < end_idx; idx++) {
            auto softmax_score = output[idx] / sum;
            auto actual_class_prob = objectness_prob * softmax_score;
            if (actual_class_prob <= class_prob_cutoff)
                actual_class_prob = T(0);
            output[idx] = actual_class_prob;
        }
    }
}
}
/* Launches the region-layer decoding kernels on `stream`.
 *
 * Two kernels are enqueued back to back on the same stream:
 *  1. raw::region_box decodes x, y, w, h and objectness of every box;
 *  2. a class score kernel (sigmoid or softmax variant) fills in the remaining values of every
 *     box, reading the objectness written by the first kernel.
 *
 * output, input       buffers of equal size holding `box_size` values per box
 * bias                (width, height) prior pairs; must satisfy is_fully_aligned(bias, 2)
 * object_prob_cutoff  objectness values below this are zeroed
 * class_prob_cutoff   class scores less than or equal to this are zeroed
 * boxes_per_cell      number of boxes predicted per grid cell
 * box_size            number of values per box; at least 5 (x, y, w, h, objectness)
 * rows, cols          grid dimensions
 * scale_x_y           scale applied to the center offsets
 * height_norm,
 * width_norm          divisors applied to the decoded height and width
 * new_coords          forwarded to the box kernel and the sigmoid class score kernel
 */
template <class T>
void region(const Stream& stream, Span<T> output, View<T> input, View<T> bias,
    T object_prob_cutoff, T class_prob_cutoff,
    std::size_t boxes_per_cell, std::size_t box_size,
    std::size_t rows, std::size_t cols, T scale_x_y,
    std::size_t height_norm, std::size_t width_norm,
    bool if_true_sigmoid_else_softmax, /* true = sigmoid, false = softmax */
    bool new_coords)
{
    /* the kernels access offsets 0..4 of every box; this also keeps the modulo and
     * divisions below from dividing by zero
     */
    CV_Assert(box_size >= 5);
    CV_Assert(output.size() == input.size());
    CV_Assert(output.size() % box_size == 0);
    CV_Assert(is_fully_aligned(bias, 2));

    const auto num_boxes = output.size() / box_size;

    auto box_kernel = raw::region_box<T>;
    auto box_policy = make_policy(box_kernel, num_boxes, 0, stream);
    launch_kernel(box_kernel, box_policy,
        output, input, bias, boxes_per_cell, box_size,
        rows, cols, scale_x_y, height_norm, width_norm,
        object_prob_cutoff, new_coords);

    if (if_true_sigmoid_else_softmax) {
        /* the sigmoid kernel processes one output element per iteration */
        auto kernel_score = raw::region_sigmoid_class_score<T>;
        auto policy_score = make_policy(kernel_score, output.size(), 0, stream);
        launch_kernel(kernel_score, policy_score, output, input, class_prob_cutoff, box_size, new_coords);
    } else {
        /* the softmax kernel processes one whole box per iteration, so the launch is sized
         * by the number of boxes rather than by the number of elements
         */
        auto kernel_score = raw::region_softmax_class_score<T>;
        auto policy_score = make_policy(kernel_score, num_boxes, 0, stream);
        launch_kernel(kernel_score, policy_score, output, input, class_prob_cutoff, box_size);
    }
}
/* Explicit instantiations of region() for the supported element types.
 *
 * The __half instantiation is compiled for the host pass (__CUDA_ARCH__ undefined) and for device
 * passes targeting compute capability 5.3 or newer; the float instantiation is always compiled.
 */
#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530)
template void region(const Stream&, Span<__half>, View<__half>, View<__half>,
__half, __half, std::size_t, std::size_t, std::size_t, std::size_t, __half, std::size_t, std::size_t, bool, bool);
#endif
template void region(const Stream&, Span<float>, View<float>, View<float>,
float, float, std::size_t, std::size_t, std::size_t, std::size_t, float, std::size_t, std::size_t, bool, bool);
}}}} /* namespace cv::dnn::cuda4dnn::kernels */