#version 450

// Workgroup width: number of invocations along x per workgroup.
#define LOCAL_SZ_X 256

// Convolution parameters delivered as push constants.
// Member order determines the block layout and must not be changed.
// M, K, N, batch, basic_shader_partition_idx and basic_shader_partition_size
// are not read by this shader; presumably they exist so the block matches a
// host-side structure shared with other shaders -- confirm against the caller.
layout(push_constant) uniform pushBlock {
    int in_h;
    int in_w;
    int out_h;
    int out_w;
    int stride_h;
    int stride_w;
    int pad_h;
    int pad_w;
    int filter_h;
    int filter_w;
    int dilation_h;
    int dilation_w;
    int channels;
    int batch;
    int has_bias;
    int M;
    int K;
    int N;
    int basic_shader_batch_idx;
    int basic_shader_partition_idx;
    int basic_shader_partition_size;
} p;

// Input values, indexed as ((batch_idx * channels + c) * in_h + y) * in_w + x.
layout(binding = 0) readonly buffer Input0 {
    float in_buffer[];
};

// One bias value per channel; read only when p.has_bias == 1.
layout(binding = 1) readonly buffer Input1 {
    float bias_data[];
};

// Filter taps, indexed as (c * filter_h + ky) * filter_w + kx.
layout(binding = 2) readonly buffer Input3 {
    float weight_data[];
};

// Output values, indexed as ((batch_idx * channels + c) * out_h + y) * out_w + x.
layout(binding = 3) writeonly buffer Output {
    float out_buffer[];
};

layout(local_size_x = LOCAL_SZ_X, local_size_y = 1, local_size_z = 1) in;

/*
 * Each invocation computes exactly one output element.
 *
 * The global invocation id maps to (x, y, z) = (output column, output row,
 * channel). Output channel c is produced from input channel c and filter c
 * only. Filter taps that land outside the input image are skipped, so they
 * contribute nothing to the sum.
 */
void main()
{
    int out_x = int(gl_GlobalInvocationID.x);
    int out_y = int(gl_GlobalInvocationID.y);
    int ch    = int(gl_GlobalInvocationID.z);

    // Invocations outside the output volume have nothing to do.
    if (out_x >= p.out_w || out_y >= p.out_h || ch >= p.channels) {
        return;
    }

    // Input coordinates hit by filter tap (0, 0); may be negative because of padding.
    int first_in_y = out_y * p.stride_h - p.pad_h;
    int first_in_x = out_x * p.stride_w - p.pad_w;

    // Index of the (batch, channel) plane shared by the input and the output.
    int plane        = p.basic_shader_batch_idx * p.channels + ch;
    int in_plane_off = plane * p.in_h * p.in_w;
    int filter_off   = ch * p.filter_h * p.filter_w;

    float acc = 0.0f;

    for (int ky = 0; ky < p.filter_h; ++ky) {
        int in_y = first_in_y + ky * p.dilation_h;

        // A whole filter row outside the image contributes nothing.
        if (in_y < 0 || in_y >= p.in_h) {
            continue;
        }

        int in_row_off     = in_plane_off + in_y * p.in_w;
        int filter_row_off = filter_off + ky * p.filter_w;

        for (int kx = 0; kx < p.filter_w; ++kx) {
            int in_x = first_in_x + kx * p.dilation_w;

            if (in_x >= 0 && in_x < p.in_w) {
                acc += in_buffer[in_row_off + in_x] * weight_data[filter_row_off + kx];
            }
        }
    }

    int out_idx = plane * p.out_h * p.out_w + out_y * p.out_w + out_x;

    // Only the selected operand of ?: is evaluated, so bias_data is not read
    // when has_bias != 1.
    out_buffer[out_idx] = (p.has_bias == 1) ? acc + bias_data[ch] : acc;
}