lrn.comp
2.13 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
#version 450
// Number of invocations per workgroup along x; consumed by the local_size_x
// layout qualifier at the end of these declarations.
#define LOCAL_SZ_X 256
// Push-constant parameters read by main().
layout(push_constant) uniform pushBlock {
// Exclusive upper bound of the work-item index loop in main(). Each index is
// decomposed into a (batch, row, column) position.
// NOTE(review): presumably batch * height * width - confirm against host code.
int thread_num;
// Number of channels walked at every (batch, row, column) position.
int channels;
// Rows per channel plane.
int height;
// Columns per row.
int width;
// Length of the sliding window in channels: the squared value of the channel
// filter_len behind the window head is removed from the running sum.
int filter_len;
// Distance (in channels) by which the written output channel lags the window
// head. NOTE(review): the window logic looks like it expects
// filter_len == 2 * radius + 1 - confirm against host code.
int radius;
// Multiplier applied to the running sum of squares.
float alpha;
// Constant added to alpha * (sum of squares) before exponentiation.
float bias;
// Exponent passed to pow(); the name suggests it already carries the negative
// sign - TODO confirm against host code.
float negative_beta;
} p;
// Read-only input values. main() indexes it as
// [batch][channel][row][column] with a channel stride of height * width.
layout(binding = 0) readonly buffer Input0{
float in_buffer[];
};
// Write-only output values, written with the same indexing as in_buffer.
layout(binding = 1) writeonly buffer Output{
float dst_buffer[];
};
// Workgroup size: LOCAL_SZ_X x 1 x 1 invocations.
layout(local_size_x = LOCAL_SZ_X, local_size_y = 1, local_size_z = 1) in;
// Returns the square of the input element stored at flat index idx.
float squared_input(int idx)
{
    float v = in_buffer[idx];
    return v * v;
}

// Scales the input element at flat index idx by
// pow(bias + window_sum * alpha, negative_beta) and stores it at the same
// index of the output buffer.
void store_normalized(int idx, float window_sum)
{
    float scale = p.bias + window_sum * p.alpha;
    dst_buffer[idx] = in_buffer[idx] * pow(scale, p.negative_beta);
}

// Entry point. Every invocation handles the indices
// gid, gid + total, gid + 2 * total, ... below thread_num, where total is the
// number of invocations launched along x. For each index it walks all
// channels of one (batch, row, column) position while maintaining a running
// sum of squared inputs over a sliding window of channels.
void main()
{
    int plane = p.height * p.width;  // distance between consecutive channels
    int step = int(gl_NumWorkGroups.x * gl_WorkGroupSize.x);

    for (int index = int(gl_GlobalInvocationID.x); index < p.thread_num; index += step)
    {
        // Decompose the flat index into column, row and batch.
        int col = index % p.width;
        int row = (index / p.width) % p.height;
        int batch = index / plane;

        // Flat index of channel 0 at this position.
        int base = batch * p.channels * p.height * p.width + row * p.width + col;

        // Channels that only feed the running sum before any output is written.
        int warmup = min(p.radius, p.channels);

        float window_sum = 0.0f;
        int head = 0;

        // Phase 1: pre-load the leading channels into the running sum.
        for (; head < warmup; ++head) {
            window_sum += squared_input(base + head * plane);
        }

        // Phase 2: add the head channel, drop the channel that left the window,
        // then emit the output for the channel `radius` behind the head.
        for (; head < p.channels; ++head) {
            window_sum += squared_input(base + head * plane);
            int tail = head - p.filter_len;
            if (tail >= 0) {
                window_sum -= squared_input(base + tail * plane);
            }
            store_normalized(base + (head - p.radius) * plane, window_sum);
        }

        // Phase 3: the head has run past the last channel; keep shrinking the
        // window and emit the outputs that are still pending.
        for (int pos = head - warmup; pos >= 0 && pos < p.channels; ++pos, ++head) {
            int tail = head - p.filter_len;
            if (tail >= 0) {
                window_sum -= squared_input(base + tail * plane);
            }
            store_normalized(base + pos * plane, window_sum);
        }
    }
}