#version 450

// Copy kernel: each invocation moves one float from `src` to a computed
// position in `dst`. The destination offset is derived from the push
// constants below.
// NOTE(review): the names suggest this writes one input of a concatenation
// into its slot of the concatenated output -- confirm against the host code.

// Workgroup width along X; Y and Z are fixed to 1.
#define LOCAL_SZ_X 256

// Parameters supplied by the host for each dispatch.
layout(push_constant) uniform pushBlock {
    int out_concat_axis;         // multiplied by the chunk number to form the destination row
                                 // (presumably the output extent along the concat axis -- TODO confirm)
    int accumulated_concat_axis; // added to the destination row
                                 // (presumably the axis offset of this input in the output -- TODO confirm)
    int concat_size;             // multiplier turning a destination row into an element offset
    int total_concat_size;       // divisor that splits the flat index into (chunk, offset in chunk)
    int thread_num;              // invocations with a global X index >= this value do nothing
} p;

// Source elements, read with the flat invocation index.
layout(binding = 0) readonly buffer Input0 {
    float data[];
} src;

// Destination elements, written at the computed offset.
layout(binding = 1) writeonly buffer Output {
    float data[];
} dst;

layout(local_size_x = LOCAL_SZ_X, local_size_y = 1, local_size_z = 1) in;

void main()
{
    int src_idx = int(gl_GlobalInvocationID.x);

    // Invocations beyond thread_num have nothing to copy.
    if (src_idx >= p.thread_num)
    {
        return;
    }

    // Split the flat source index into a chunk number and an offset inside that chunk.
    int chunk = src_idx / p.total_concat_size;
    int offset_in_chunk = src_idx % p.total_concat_size;

    // Destination row for this chunk, shifted by the accumulated axis offset.
    int dst_row = chunk * p.out_concat_axis + p.accumulated_concat_axis;
    int dst_idx = offset_in_chunk + dst_row * p.concat_size;

    dst.data[dst_idx] = src.data[src_idx];
}