// conv48.comp
#version 450

// Specialization constants. Every default below is 0; the actual values are
// presumably supplied by the host when the pipeline is created (TODO confirm
// against the host-side code).
// Ids 0-2 are also referenced by the local_size_{x,y,z}_id layout qualifier.
layout (constant_id = 0) const int LOCAL_SZ_X = 0;
layout (constant_id = 1) const int LOCAL_SZ_Y = 0;
layout (constant_id = 2) const int LOCAL_SZ_Z = 0;
// Input extent: used for the bounds check in LOAD_A and as the plane stride
// (IN_H * IN_W) when indexing src0.
layout (constant_id = 3) const int IN_H = 0;
layout (constant_id = 4) const int IN_W = 0;
// Output width: used to split a flattened output position into (row, column).
layout (constant_id = 5) const int OUT_W = 0;
// Stride and padding: input origin = output coordinate * STRIDE - PAD.
layout (constant_id = 6) const int STRIDE_H = 0;
layout (constant_id = 7) const int STRIDE_W = 0;
layout (constant_id = 8) const int PAD_H = 0;
layout (constant_id = 9) const int PAD_W = 0;
// Filter extent: used to decompose a reduction index into (plane, tap row, tap column).
layout (constant_id = 10) const int FILTER_H = 0;
layout (constant_id = 11) const int FILTER_W = 0;
// Number of input planes per batch item (enters the per-batch input offset).
layout (constant_id = 12) const int CHANNELS = 0;
// Not referenced by the shader code visible in this file.
layout (constant_id = 13) const int BATCH = 0;
// Matrix dimensions as used by main():
//   M - number of flattened output positions (row length of out0, in floats),
//   K - length of the reduction dimension (row length of src1, in floats),
//   N - number of src1 / out0 rows.
layout (constant_id = 14) const int M = 0;
layout (constant_id = 15) const int K = 0;
layout (constant_id = 16) const int N = 0;
// Not referenced by the shader code visible in this file.
layout (constant_id = 17) const int TAIL_M = 0;
// Dilation: spacing (in input samples) between neighbouring filter taps.
layout (constant_id = 18) const int DILATION_H = 0;
layout (constant_id = 19) const int DILATION_W = 0;

// Optional activation applied to each vec4 result right before it is stored.
// The first ACTIVATION_* macro that is defined in the chain below selects the
// function; with none defined the value is passed through unchanged.
#if defined(ACTIVATION_RELU)
// ReLU: max(x, 0). There is deliberately no upper bound, so large positive
// results (including +inf) are passed through instead of being saturated.
#define ACTIVATION_FUNCTION(x)  max(x, vec4(0.0))
#elif defined(ACTIVATION_RELU1)
// ReLU1: clamp to [-1, 1].
#define ACTIVATION_FUNCTION(x)  clamp(x, vec4(-1.0), vec4(1.0))
#elif defined(ACTIVATION_RELU6)
// ReLU6: clamp to [0, 6].
#define ACTIVATION_FUNCTION(x)  clamp(x, vec4(0.0), vec4(6.0))
#else
// No activation.
#define ACTIVATION_FUNCTION(x)  (x)
#endif

// Input samples as a flat float array. LOAD_A indexes it as
//   batch offset + plane * (IN_H * IN_W) + y * IN_W + x.
layout(binding = 0) readonly buffer Input0{
    float data[];
} src0;
// Bias values packed as vec4. main() reads two consecutive vec4s per
// invocation (indices 2 * gx and 2 * gx + 1), one scalar per handled row.
layout(binding = 1) readonly buffer Input1 {
    vec4 data[];
} bias;
// Second matrix operand packed as vec4. main() reads row r at vec4 offset
// r * (K / VEC_SIZE). Presumably the filter weights -- TODO confirm.
layout(binding = 2) readonly buffer Input3{
    vec4 data[];
} src1;
// Result buffer packed as vec4; only written, never read.
layout(binding = 3) writeonly buffer Output{
    vec4 data[];
} out0;

// Workgroup size is taken from specialization constants 0, 1 and 2.
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

// Number of scalars per vec4; the reduction dimension K is walked in steps of
// this size.
#define VEC_SIZE 4
// Flattened output positions handled per invocation (one per vec4 component
// of each accumulator).
#define BLOCK_H 4
// Rows of src1 / out0 handled per invocation (one accumulator each).
#define BLOCK_W 8
// Number of taps in one filter plane.
#define FILTER_AREA (FILTER_H * FILTER_W)
// LOAD_A(elm_idx, a_component): fetch the input sample that belongs to
// reduction index e = i * VEC_SIZE + elm_idx and store it in a_component.
//   e % FILTER_W                  -> tap column, scaled by DILATION_W
//   (e % FILTER_AREA) / FILTER_W  -> tap row, scaled by DILATION_H
//   e / FILTER_AREA               -> input plane index
// When the resulting (src0_y, src0_x) lies outside [0, IN_H) x [0, IN_W)
// nothing is loaded and a_component keeps its previous value; main()
// initialises the a* vectors to zero on every loop iteration, so such taps
// contribute 0 to the sum.
// The expansion relies on these names being in scope at the call site:
// i, org_x, org_y, src0_x, src0_y, src0_z, input_batch_offset.
#define LOAD_A(elm_idx, a_component) \
            src0_x = org_x + ((i * VEC_SIZE + elm_idx) % FILTER_W) * DILATION_W; \
            src0_y = org_y + (((i * VEC_SIZE + elm_idx) % FILTER_AREA) / FILTER_W) * DILATION_H; \
            src0_z = (i * VEC_SIZE + elm_idx) / FILTER_AREA; \
            if(src0_y >= 0 && src0_y < IN_H && src0_x >= 0 && src0_x < IN_W) \
            { \
                a_component = src0.data[input_batch_offset + src0_z * (IN_H * IN_W) + src0_y * IN_W + src0_x]; \
            }

// A_MULTIPLY_BTILE(a, sliver_num, comp): accumulate the contribution of the
// current reduction step for flattened output position out_y + sliver_num.
//   1. Split the position into (dst_y, dst_x) using OUT_W and map it to the
//      input origin (org_y, org_x) via stride and padding.
//   2. Gather four consecutive reduction elements into `a` with LOAD_A.
//   3. Add dot(browR, a) to component `comp` of dotR for R = 0..7.
// The expansion relies on out_y, dst_x, dst_y, org_x, org_y, brow0..brow7 and
// dot0..dot7 being in scope at the call site, plus everything LOAD_A needs.
#define A_MULTIPLY_BTILE(a, sliver_num, comp) \
            dst_x = (out_y + sliver_num) % OUT_W; \
            dst_y = (out_y + sliver_num) / OUT_W; \
            org_y = dst_y * STRIDE_H - PAD_H; \
            org_x = dst_x * STRIDE_W - PAD_W; \
            LOAD_A(0, a.x); \
            LOAD_A(1, a.y); \
            LOAD_A(2, a.z); \
            LOAD_A(3, a.w); \
            dot0.comp += dot(brow0, a); \
            dot1.comp += dot(brow1, a); \
            dot2.comp += dot(brow2, a); \
            dot3.comp += dot(brow3, a); \
            dot4.comp += dot(brow4, a); \
            dot5.comp += dot(brow5, a); \
            dot6.comp += dot(brow6, a); \
            dot7.comp += dot(brow7, a);

// Shader entry point: computes one BLOCK_W x BLOCK_H tile of the result.
//
// Invocation (gx, gy, gz) handles
//   - rows out_x .. out_x + 7 of src1 / out0      (out_x = BLOCK_W * gx),
//   - flattened output positions out_y .. out_y + 3 (out_y = BLOCK_H * gy),
//   - batch item gz.
// Accumulator dotR belongs to row out_x + R; its x/y/z/w components belong to
// positions out_y + 0/1/2/3, so each dotR is written back as a single vec4.
//
// NOTE(review): the guard only tests the first row and the vec4 column of the
// tile, so this presumably requires N % BLOCK_W == 0, M % BLOCK_H == 0 and
// K % VEC_SIZE == 0 -- confirm against the host-side dispatch code.
void main()
{
    int gx = int(gl_GlobalInvocationID.x);
    int gy = int(gl_GlobalInvocationID.y);
    int gz = int(gl_GlobalInvocationID.z);
    int out_x = BLOCK_W * gx;
    int out_y = BLOCK_H * gy;
    int input_batch_offset  = gz * IN_H * IN_W * CHANNELS;   // in floats (src0)
    int output_batch_offset = gz * M * N / VEC_SIZE;         // in vec4s  (out0)
    if (out_x < N && gy < M / BLOCK_H)
    {
        // Length of one src1 row, counted in vec4s.
        int width0 = K / VEC_SIZE;

        vec4 dot0 = vec4(0.f);
        vec4 dot1 = vec4(0.f);
        vec4 dot2 = vec4(0.f);
        vec4 dot3 = vec4(0.f);
        vec4 dot4 = vec4(0.f);
        vec4 dot5 = vec4(0.f);
        vec4 dot6 = vec4(0.f);
        vec4 dot7 = vec4(0.f);

        // Walk the reduction dimension one vec4 (VEC_SIZE scalars) at a time.
        // The do/while form means the body always runs at least once.
        int i = 0;
        do
        {
            // Scratch variables written by the LOAD_A / A_MULTIPLY_BTILE macros.
            int dst_x, dst_y, org_x, org_y, src0_x, src0_y, src0_z;
            // Input slivers start at zero so that taps skipped by LOAD_A's
            // bounds check contribute nothing.
            vec4 a0 = vec4(0.f), a1 = vec4(0.f), a2 = vec4(0.f), a3 = vec4(0.f);

            // i-th vec4 of each of the eight src1 rows handled by this invocation.
            vec4 brow0 = src1.data[(out_x + 0) * width0 + i];
            vec4 brow1 = src1.data[(out_x + 1) * width0 + i];
            vec4 brow2 = src1.data[(out_x + 2) * width0 + i];
            vec4 brow3 = src1.data[(out_x + 3) * width0 + i];
            vec4 brow4 = src1.data[(out_x + 4) * width0 + i];
            vec4 brow5 = src1.data[(out_x + 5) * width0 + i];
            vec4 brow6 = src1.data[(out_x + 6) * width0 + i];
            vec4 brow7 = src1.data[(out_x + 7) * width0 + i];

            // One call per output position; the last argument selects the
            // accumulator component that position maps to.
            A_MULTIPLY_BTILE(a0, 0, x);
            A_MULTIPLY_BTILE(a1, 1, y);
            A_MULTIPLY_BTILE(a2, 2, z);
            A_MULTIPLY_BTILE(a3, 3, w);
            i++;
        }
        while( i < width0 );

        // Add the per-row bias: two vec4s hold the eight scalars for rows
        // out_x .. out_x + 7, each broadcast over the four output positions.
        vec4 bias_val;
        bias_val = bias.data[2 * gx];
        dot0 += bias_val.xxxx; dot1 += bias_val.yyyy; dot2 += bias_val.zzzz; dot3 += bias_val.wwww;
        bias_val = bias.data[2 * gx + 1];
        dot4 += bias_val.xxxx; dot5 += bias_val.yyyy; dot6 += bias_val.zzzz; dot7 += bias_val.wwww;

        // Store one vec4 per row: row r starts at r * M / VEC_SIZE and gy
        // selects the vec4 that covers positions out_y .. out_y + 3.
        out0.data[output_batch_offset + (out_x + 0) * M / VEC_SIZE + gy] = ACTIVATION_FUNCTION(dot0);
        out0.data[output_batch_offset + (out_x + 1) * M / VEC_SIZE + gy] = ACTIVATION_FUNCTION(dot1);
        out0.data[output_batch_offset + (out_x + 2) * M / VEC_SIZE + gy] = ACTIVATION_FUNCTION(dot2);
        out0.data[output_batch_offset + (out_x + 3) * M / VEC_SIZE + gy] = ACTIVATION_FUNCTION(dot3);
        out0.data[output_batch_offset + (out_x + 4) * M / VEC_SIZE + gy] = ACTIVATION_FUNCTION(dot4);
        out0.data[output_batch_offset + (out_x + 5) * M / VEC_SIZE + gy] = ACTIVATION_FUNCTION(dot5);
        out0.data[output_batch_offset + (out_x + 6) * M / VEC_SIZE + gy] = ACTIVATION_FUNCTION(dot6);
        out0.data[output_batch_offset + (out_x + 7) * M / VEC_SIZE + gy] = ACTIVATION_FUNCTION(dot7);
    }
}