conv48.comp
5.87 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
#version 450
// Specialization constants. Every default is 0; the shader divides by OUT_W,
// FILTER_W and FILTER_AREA, so these are presumably always overridden by the
// host when the pipeline is created -- TODO confirm against the host code.
//
// Ids 0-2 are also bound to the workgroup size through the
// local_size_{x,y,z}_id layout qualifier further down. The LOCAL_SZ_* names
// themselves are not referenced anywhere else in this file.
layout (constant_id = 0) const int LOCAL_SZ_X = 0;
layout (constant_id = 1) const int LOCAL_SZ_Y = 0;
layout (constant_id = 2) const int LOCAL_SZ_Z = 0;
// Input extent: used for the bounds test in LOAD_A and for input strides.
layout (constant_id = 3) const int IN_H = 0;
layout (constant_id = 4) const int IN_W = 0;
// Used to split a linear output position into (dst_x, dst_y).
layout (constant_id = 5) const int OUT_W = 0;
// Map an output coordinate to the top-left input coordinate of its window:
// org = dst * STRIDE - PAD.
layout (constant_id = 6) const int STRIDE_H = 0;
layout (constant_id = 7) const int STRIDE_W = 0;
layout (constant_id = 8) const int PAD_H = 0;
layout (constant_id = 9) const int PAD_W = 0;
// Filter window size; FILTER_H * FILTER_W is FILTER_AREA.
layout (constant_id = 10) const int FILTER_H = 0;
layout (constant_id = 11) const int FILTER_W = 0;
// Only used to compute the per-batch input stride (IN_H * IN_W * CHANNELS).
layout (constant_id = 12) const int CHANNELS = 0;
// Not referenced in this file; the batch index comes from gl_GlobalInvocationID.z.
layout (constant_id = 13) const int BATCH = 0;
// Problem dimensions as used by main():
//   M - number of linear output positions per batch (guard: gy < M / BLOCK_H)
//   K - reduction length; the loop runs K / VEC_SIZE iterations
//   N - number of output rows addressed by out_x (guard: out_x < N)
// presumably M = out_h * out_w, K = channels * filter area, N = output
// channels -- TODO confirm against the host code.
layout (constant_id = 14) const int M = 0;
layout (constant_id = 15) const int K = 0;
layout (constant_id = 16) const int N = 0;
// Not referenced in this file.
layout (constant_id = 17) const int TAIL_M = 0;
// Spacing between filter taps in input coordinates.
layout (constant_id = 18) const int DILATION_H = 0;
layout (constant_id = 19) const int DILATION_W = 0;
// ACTIVATION_FUNCTION(x) is applied to every output vec4 just before it is
// stored. The variant is selected at compile time by defining at most one of
// ACTIVATION_RELU, ACTIVATION_RELU1 or ACTIVATION_RELU6; with none defined the
// value is passed through unchanged.
#if defined(ACTIVATION_RELU)
// ReLU is max(x, 0). The previous clamp(x, 0, 999999999.0) silently saturated
// any output above ~1e9 (and +inf), which is not part of the ReLU definition.
#define ACTIVATION_FUNCTION(x) max(x, vec4(0.0))
#elif defined(ACTIVATION_RELU1)
// Clamp to [-1, 1].
#define ACTIVATION_FUNCTION(x) clamp(x, vec4(-1.0), vec4(1.0))
#elif defined(ACTIVATION_RELU6)
// Clamp to [0, 6].
#define ACTIVATION_FUNCTION(x) clamp(x, vec4(0.0), vec4(6.0))
#else
// No activation: identity.
#define ACTIVATION_FUNCTION(x) (x)
#endif
// Input feature map, read one float at a time by LOAD_A at
//   input_batch_offset + src0_z * (IN_H * IN_W) + src0_y * IN_W + src0_x
// i.e. a per-batch block of planes, each IN_H rows of IN_W floats.
layout(binding = 0) readonly buffer Input0{
float data[];
} src0;
// Bias, read as two consecutive vec4s per invocation (indices 2*gx and
// 2*gx + 1); the eight scalars are added to dot0..dot7 respectively.
layout(binding = 1) readonly buffer Input1 {
vec4 data[];
} bias;
// Weights, read as vec4 at (out_x + r) * (K / VEC_SIZE) + i for r in 0..7,
// so each of the addressed rows holds K / VEC_SIZE vec4s.
layout(binding = 2) readonly buffer Input3{
vec4 data[];
} src1;
// Output, written as vec4 at
//   output_batch_offset + (out_x + r) * M / VEC_SIZE + gy
// Each vec4 packs four consecutive linear output positions of one row.
layout(binding = 3) writeonly buffer Output{
vec4 data[];
} out0;
// Workgroup size is taken from specialization constants 0, 1 and 2.
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
// Tiling parameters.
//   VEC_SIZE - floats per vec4; also the number of K elements consumed per
//              loop iteration in main().
//   BLOCK_H  - linear output positions handled per invocation
//              (out_y = BLOCK_H * gy; one A_MULTIPLY_BTILE call per position).
//   BLOCK_W  - weight/output rows handled per invocation
//              (out_x = BLOCK_W * gx; accumulators dot0..dot7).
#define VEC_SIZE 4
#define BLOCK_H 4
#define BLOCK_W 8
#define FILTER_AREA (FILTER_H * FILTER_W)
// LOAD_A(elm_idx, a_component)
// Loads one input float for reduction index k = i * VEC_SIZE + elm_idx.
// k is decomposed as: k % FILTER_W -> filter column,
// (k % FILTER_AREA) / FILTER_W -> filter row, k / FILTER_AREA -> input plane.
// Reads i, org_x, org_y and input_batch_offset from the enclosing scope and
// overwrites src0_x, src0_y, src0_z there.
// a_component is assigned only when (src0_x, src0_y) lies inside the input;
// otherwise it is left untouched, so the caller must pre-zero it to get
// zero padding. There is no bounds check on src0_z.
// NOTE: keep comments out of the macro body; a // comment on a
// backslash-continued line would swallow the rest of the macro.
#define LOAD_A(elm_idx, a_component) \
src0_x = org_x + ((i * VEC_SIZE + elm_idx) % FILTER_W) * DILATION_W; \
src0_y = org_y + (((i * VEC_SIZE + elm_idx) % FILTER_AREA) / FILTER_W) * DILATION_H; \
src0_z = (i * VEC_SIZE + elm_idx) / FILTER_AREA; \
if(src0_y >= 0 && src0_y < IN_H && src0_x >= 0 && src0_x < IN_W) \
{ \
a_component = src0.data[input_batch_offset + src0_z * (IN_H * IN_W) + src0_y * IN_W + src0_x]; \
}
// A_MULTIPLY_BTILE(a, sliver_num, comp)
// Handles linear output position out_y + sliver_num for the current i:
//   1. converts the position to (dst_x, dst_y) using OUT_W,
//   2. derives the window origin (org_x, org_y) from stride and padding,
//   3. fills the vec4 `a` with four input samples via LOAD_A,
//   4. adds dot(browN, a) to component `comp` of dotN for N = 0..7.
// Relies on out_y, brow0..brow7, dot0..dot7 and the scratch ints dst_x,
// dst_y, org_x, org_y being declared in the enclosing scope.
#define A_MULTIPLY_BTILE(a, sliver_num, comp) \
dst_x = (out_y + sliver_num) % OUT_W; \
dst_y = (out_y + sliver_num) / OUT_W; \
org_y = dst_y * STRIDE_H - PAD_H; \
org_x = dst_x * STRIDE_W - PAD_W; \
LOAD_A(0, a.x); \
LOAD_A(1, a.y); \
LOAD_A(2, a.z); \
LOAD_A(3, a.w); \
dot0.comp += dot(brow0, a); \
dot1.comp += dot(brow1, a); \
dot2.comp += dot(brow2, a); \
dot3.comp += dot(brow3, a); \
dot4.comp += dot(brow4, a); \
dot5.comp += dot(brow5, a); \
dot6.comp += dot(brow6, a); \
dot7.comp += dot(brow7, a);
// Entry point. One invocation computes a BLOCK_W x BLOCK_H output tile:
// eight rows (out_x .. out_x + 7) by four consecutive linear output positions
// (out_y .. out_y + 3) of batch gz. Each accumulator dotN holds row
// out_x + N, with components x..w mapping to positions out_y + 0..3.
//
// Several locals are captured BY NAME by LOAD_A / A_MULTIPLY_BTILE and must
// not be renamed: i, out_y, input_batch_offset, dst_x, dst_y, org_x, org_y,
// src0_x, src0_y, src0_z, brow0..brow7 and dot0..dot7.
void main()
{
    int gx = int(gl_GlobalInvocationID.x);
    int gy = int(gl_GlobalInvocationID.y);
    int gz = int(gl_GlobalInvocationID.z);

    int out_x = BLOCK_W * gx;
    int out_y = BLOCK_H * gy;
    int input_batch_offset = gz * IN_H * IN_W * CHANNELS;
    int output_batch_offset = gz * M * N / VEC_SIZE;

    // Invocations that fall outside the tiled output area have nothing to do.
    if (out_x >= N || gy >= M / BLOCK_H)
    {
        return;
    }

    // Number of vec4s in one weight row, and the loop trip count.
    int vecs_per_row = K / VEC_SIZE;

    vec4 dot0 = vec4(0.0);
    vec4 dot1 = vec4(0.0);
    vec4 dot2 = vec4(0.0);
    vec4 dot3 = vec4(0.0);
    vec4 dot4 = vec4(0.0);
    vec4 dot5 = vec4(0.0);
    vec4 dot6 = vec4(0.0);
    vec4 dot7 = vec4(0.0);

    // The body always runs at least once, even when vecs_per_row is 0.
    int i = 0;
    do
    {
        // Scratch coordinates written by the macros.
        int dst_x, dst_y, org_x, org_y, src0_x, src0_y, src0_z;

        // Input samples for the four output positions; zero-initialised so
        // that out-of-bounds taps skipped by LOAD_A contribute nothing.
        vec4 tap0 = vec4(0.0);
        vec4 tap1 = vec4(0.0);
        vec4 tap2 = vec4(0.0);
        vec4 tap3 = vec4(0.0);

        // i-th vec4 of each of the eight weight rows out_x .. out_x + 7.
        int w_idx = out_x * vecs_per_row + i;
        vec4 brow0 = src1.data[w_idx + 0 * vecs_per_row];
        vec4 brow1 = src1.data[w_idx + 1 * vecs_per_row];
        vec4 brow2 = src1.data[w_idx + 2 * vecs_per_row];
        vec4 brow3 = src1.data[w_idx + 3 * vecs_per_row];
        vec4 brow4 = src1.data[w_idx + 4 * vecs_per_row];
        vec4 brow5 = src1.data[w_idx + 5 * vecs_per_row];
        vec4 brow6 = src1.data[w_idx + 6 * vecs_per_row];
        vec4 brow7 = src1.data[w_idx + 7 * vecs_per_row];

        A_MULTIPLY_BTILE(tap0, 0, x);
        A_MULTIPLY_BTILE(tap1, 1, y);
        A_MULTIPLY_BTILE(tap2, 2, z);
        A_MULTIPLY_BTILE(tap3, 3, w);
    }
    while (++i < vecs_per_row);

    // Eight bias scalars, one per output row, broadcast over the positions.
    vec4 bias_lo = bias.data[2 * gx];
    vec4 bias_hi = bias.data[2 * gx + 1];
    dot0 += vec4(bias_lo.x);
    dot1 += vec4(bias_lo.y);
    dot2 += vec4(bias_lo.z);
    dot3 += vec4(bias_lo.w);
    dot4 += vec4(bias_hi.x);
    dot5 += vec4(bias_hi.y);
    dot6 += vec4(bias_hi.z);
    dot7 += vec4(bias_hi.w);

    // Row r starts at (out_x + r) * M / VEC_SIZE; gy selects the vec4 in it.
    int out_base = output_batch_offset + gy;
    out0.data[out_base + (out_x + 0) * M / VEC_SIZE] = ACTIVATION_FUNCTION(dot0);
    out0.data[out_base + (out_x + 1) * M / VEC_SIZE] = ACTIVATION_FUNCTION(dot1);
    out0.data[out_base + (out_x + 2) * M / VEC_SIZE] = ACTIVATION_FUNCTION(dot2);
    out0.data[out_base + (out_x + 3) * M / VEC_SIZE] = ACTIVATION_FUNCTION(dot3);
    out0.data[out_base + (out_x + 4) * M / VEC_SIZE] = ACTIVATION_FUNCTION(dot4);
    out0.data[out_base + (out_x + 5) * M / VEC_SIZE] = ACTIVATION_FUNCTION(dot5);
    out0.data[out_base + (out_x + 6) * M / VEC_SIZE] = ACTIVATION_FUNCTION(dot6);
    out0.data[out_base + (out_x + 7) * M / VEC_SIZE] = ACTIVATION_FUNCTION(dot7);
}