#version 450

// Convolution compute shader written as a tiled matrix product.
//
// Each invocation produces a BLOCK_W x BLOCK_H (8 x 4) output tile:
//   * 8 consecutive indices along N, starting at out_x = BLOCK_W * gx;
//   * 4 consecutive indices along M, starting at out_y = BLOCK_H * gy, where an
//     M index is decoded into an output pixel as (idx % OUT_W, idx / OUT_W).
// The K dimension is walked VEC_SIZE (4) elements at a time. A K index is
// decoded as (channel, filter row, filter column) through FILTER_AREA and
// FILTER_W, and the matching input sample is fetched directly from src0, so no
// intermediate im2col buffer is read by this shader.
//
// NOTE(review): the loop bound K / VEC_SIZE and the unconditional writes of
// eight N columns suggest the host only dispatches this shader when K is a
// multiple of 4 and N is a multiple of 8 -- confirm against the caller.

// Specialization constants. Ids 0..2 also drive the work-group size through
// local_size_{x,y,z}_id below.
layout (constant_id = 0) const int LOCAL_SZ_X = 0;
layout (constant_id = 1) const int LOCAL_SZ_Y = 0;
layout (constant_id = 2) const int LOCAL_SZ_Z = 0;
layout (constant_id = 3) const int IN_H = 0;
layout (constant_id = 4) const int IN_W = 0;
layout (constant_id = 5) const int OUT_W = 0;
layout (constant_id = 6) const int STRIDE_H = 0;
layout (constant_id = 7) const int STRIDE_W = 0;
layout (constant_id = 8) const int PAD_H = 0;
layout (constant_id = 9) const int PAD_W = 0;
layout (constant_id = 10) const int FILTER_H = 0;
layout (constant_id = 11) const int FILTER_W = 0;
layout (constant_id = 12) const int CHANNELS = 0;
layout (constant_id = 13) const int BATCH = 0;
layout (constant_id = 14) const int M = 0;
layout (constant_id = 15) const int K = 0;
layout (constant_id = 16) const int N = 0;
layout (constant_id = 17) const int TAIL_M = 0;
layout (constant_id = 18) const int DILATION_H = 0;
layout (constant_id = 19) const int DILATION_W = 0;

// Activation applied to every output vec4; selected at shader compile time.
// With no ACTIVATION_* define the value is passed through unchanged.
#if defined(ACTIVATION_RELU)
#define ACTIVATION_FUNCTION(x) clamp(x, vec4(0.0), vec4(999999999.0))
#elif defined(ACTIVATION_RELU1)
#define ACTIVATION_FUNCTION(x) clamp(x, vec4(-1.0), vec4(1.0))
#elif defined(ACTIVATION_RELU6)
#define ACTIVATION_FUNCTION(x) clamp(x, vec4(0.0), vec4(6.0))
#else
#define ACTIVATION_FUNCTION(x) (x)
#endif

// src0: input samples, addressed per scalar float.
layout(binding = 0) readonly buffer Input0{
    float data[];
} src0;
// bias: two consecutive vec4 (eight values) are read per invocation.
layout(binding = 1) readonly buffer Input1 {
    vec4 data[];
} bias;
// src1: second matrix operand; each N index owns K / VEC_SIZE consecutive vec4.
layout(binding = 2) readonly buffer Input3{
    vec4 data[];
} src1;
// out0: results; one vec4 holds four consecutive M positions of one N index.
layout(binding = 3) writeonly buffer Output{
    vec4 data[];
} out0;

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

#define VEC_SIZE 4
#define BLOCK_H 4
#define BLOCK_W 8
#define FILTER_AREA (FILTER_H * FILTER_W)

// LOAD_A(elm_idx, a_component)
// Fetches the input sample for K index (i * VEC_SIZE + elm_idx) at the window
// origin (org_x, org_y) into a_component. When the sample position falls
// outside [0, IN_W) x [0, IN_H) nothing is written, so a_component keeps the
// zero it was initialised with. Relies on these names from the expansion site:
// i, org_x, org_y, src0_x, src0_y, src0_z, input_batch_offset.
#define LOAD_A(elm_idx, a_component) \
    src0_x = org_x + ((i * VEC_SIZE + elm_idx) % FILTER_W) * DILATION_W; \
    src0_y = org_y + (((i * VEC_SIZE + elm_idx) % FILTER_AREA) / FILTER_W) * DILATION_H; \
    src0_z = (i * VEC_SIZE + elm_idx) / FILTER_AREA; \
    if (src0_y >= 0 && src0_y < IN_H && src0_x >= 0 && src0_x < IN_W) \
    { \
        a_component = src0.data[input_batch_offset + src0_z * (IN_H * IN_W) + src0_y * IN_W + src0_x]; \
    }

// A_MULTIPLY_BTILE(a, sliver_num, comp)
// Handles M index (out_y + sliver_num): decodes its output pixel, derives the
// top-left input coordinate of its window, gathers four K samples into `a`, and
// accumulates dot(browN, a) into component `comp` of dot0..dot7.
#define A_MULTIPLY_BTILE(a, sliver_num, comp) \
    dst_x = (out_y + sliver_num) % OUT_W; \
    dst_y = (out_y + sliver_num) / OUT_W; \
    org_y = dst_y * STRIDE_H - PAD_H; \
    org_x = dst_x * STRIDE_W - PAD_W; \
    LOAD_A(0, a.x); \
    LOAD_A(1, a.y); \
    LOAD_A(2, a.z); \
    LOAD_A(3, a.w); \
    dot0.comp += dot(brow0, a); \
    dot1.comp += dot(brow1, a); \
    dot2.comp += dot(brow2, a); \
    dot3.comp += dot(brow3, a); \
    dot4.comp += dot(brow4, a); \
    dot5.comp += dot(brow5, a); \
    dot6.comp += dot(brow6, a); \
    dot7.comp += dot(brow7, a);

void main()
{
    int gx = int(gl_GlobalInvocationID.x);
    int gy = int(gl_GlobalInvocationID.y);
    int gz = int(gl_GlobalInvocationID.z);

    // First N index and first M index of this invocation's tile.
    int out_x = BLOCK_W * gx;
    int out_y = BLOCK_H * gy;

    // gz selects the image: it offsets both the input and the output buffer.
    int input_batch_offset = gz * IN_H * IN_W * CHANNELS;
    int output_batch_offset = gz * M / VEC_SIZE * N;

    // Only invocations whose tile starts inside N and whose four M rows are
    // all below M do any work.
    if (out_x < N && gy < M / BLOCK_H)
    {
        // Number of vec4 steps along K; also the vec4 stride between
        // consecutive N rows of src1.
        int width0 = K / VEC_SIZE;
        int src1_read0_offset = out_x * width0;

        // dotJ accumulates N index (out_x + J); components x..w correspond to
        // M indices out_y + 0..3.
        vec4 dot0 = vec4(0.f);
        vec4 dot1 = vec4(0.f);
        vec4 dot2 = vec4(0.f);
        vec4 dot3 = vec4(0.f);
        vec4 dot4 = vec4(0.f);
        vec4 dot5 = vec4(0.f);
        vec4 dot6 = vec4(0.f);
        vec4 dot7 = vec4(0.f);

        // The body runs at least once, even when width0 is 0.
        int i = 0;
        do
        {
            int dst_x, dst_y, org_x, org_y, src0_x, src0_y, src0_z;

            // Re-zeroed every step: LOAD_A leaves out-of-range samples untouched.
            vec4 a0 = vec4(0.f), a1 = vec4(0.f), a2 = vec4(0.f), a3 = vec4(0.f);

            // Read the i-th vec4 of eight consecutive src1 rows.
            vec4 brow0 = src1.data[src1_read0_offset]; src1_read0_offset += width0;
            vec4 brow1 = src1.data[src1_read0_offset]; src1_read0_offset += width0;
            vec4 brow2 = src1.data[src1_read0_offset]; src1_read0_offset += width0;
            vec4 brow3 = src1.data[src1_read0_offset]; src1_read0_offset += width0;
            vec4 brow4 = src1.data[src1_read0_offset]; src1_read0_offset += width0;
            vec4 brow5 = src1.data[src1_read0_offset]; src1_read0_offset += width0;
            vec4 brow6 = src1.data[src1_read0_offset]; src1_read0_offset += width0;
            vec4 brow7 = src1.data[src1_read0_offset]; src1_read0_offset += width0;

            // Rewind the eight row strides and advance one vec4 along K.
            src1_read0_offset += 1 - BLOCK_W * width0;

            A_MULTIPLY_BTILE(a0, 0, x);
            A_MULTIPLY_BTILE(a1, 1, y);
            A_MULTIPLY_BTILE(a2, 2, z);
            A_MULTIPLY_BTILE(a3, 3, w);

            i++;
        }
        while (i < width0);

        // Eight bias values per invocation, one per N index of the tile,
        // broadcast across the four M positions.
        vec4 bias_val;
        bias_val = bias.data[2 * gx];
        dot0 += bias_val.xxxx;
        dot1 += bias_val.yyyy;
        dot2 += bias_val.zzzz;
        dot3 += bias_val.wwww;
        bias_val = bias.data[2 * gx + 1];
        dot4 += bias_val.xxxx;
        dot5 += bias_val.yyyy;
        dot6 += bias_val.zzzz;
        dot7 += bias_val.wwww;

        // One vec4 store per N index; gy picks the vec4 slot along M.
        out0.data[output_batch_offset + (out_x + 0) * M / VEC_SIZE + gy] = ACTIVATION_FUNCTION(dot0);
        out0.data[output_batch_offset + (out_x + 1) * M / VEC_SIZE + gy] = ACTIVATION_FUNCTION(dot1);
        out0.data[output_batch_offset + (out_x + 2) * M / VEC_SIZE + gy] = ACTIVATION_FUNCTION(dot2);
        out0.data[output_batch_offset + (out_x + 3) * M / VEC_SIZE + gy] = ACTIVATION_FUNCTION(dot3);
        out0.data[output_batch_offset + (out_x + 4) * M / VEC_SIZE + gy] = ACTIVATION_FUNCTION(dot4);
        out0.data[output_batch_offset + (out_x + 5) * M / VEC_SIZE + gy] = ACTIVATION_FUNCTION(dot5);
        out0.data[output_batch_offset + (out_x + 6) * M / VEC_SIZE + gy] = ACTIVATION_FUNCTION(dot6);
        out0.data[output_batch_offset + (out_x + 7) * M / VEC_SIZE + gy] = ACTIVATION_FUNCTION(dot7);
    }
}