concat.comp
778 Bytes
#version 450
// Compute shader that copies the elements of one input buffer into a second
// buffer at offsets computed from the push constants below (see main()).

// Workgroup width along X; used for local_size_x in the layout declaration below.
#define LOCAL_SZ_X 256

// Push-constant parameters. The notes describe how main() uses each field;
// the tensor-level meaning is inferred from the names and should be confirmed
// against the host code that fills this block.
layout(push_constant) uniform pushBlock {
// Multiplied by the chunk number when forming the destination index.
// Presumably the output's extent along the concat axis - TODO confirm.
int out_concat_axis;
// Added to (chunk number * out_concat_axis) before scaling by concat_size.
// Presumably where this input starts along the output's concat axis - TODO confirm.
int accumulated_concat_axis;
// Scale factor applied to the axis position computed above.
// Presumably the number of elements per step along the concat axis - TODO confirm.
int concat_size;
// Divisor that splits the flat source index into a chunk number (quotient)
// and an offset inside the chunk (remainder). Must be non-zero.
int total_concat_size;
// Invocations whose global X index is >= thread_num do nothing.
int thread_num;
} p;

// Source elements, read with the flat global invocation index.
layout(binding = 0) readonly buffer Input0{
float data[];
} src;

// Destination elements, written at the index computed in main().
layout(binding = 1) writeonly buffer Output{
float data[];
} dst;

// 1-D workgroup: LOCAL_SZ_X invocations along X, 1 along Y and Z.
layout(local_size_x = LOCAL_SZ_X, local_size_y = 1, local_size_z = 1) in;
// Copies src.data[i] to its computed slot in dst.data, where i is this
// invocation's global X index. Invocations with i >= p.thread_num do nothing.
void main()
{
    int src_idx = int(gl_GlobalInvocationID.x);
    if (src_idx >= p.thread_num)
        return;

    // Split the flat source index into a chunk number and an offset inside
    // that chunk, using p.total_concat_size as the chunk length.
    int chunk = src_idx / p.total_concat_size;
    int offset_in_chunk = src_idx % p.total_concat_size;

    // Start of this chunk in the destination: the chunk's axis position
    // (chunk * out_concat_axis), shifted by accumulated_concat_axis,
    // then scaled by concat_size.
    int dst_base = (chunk * p.out_concat_axis + p.accumulated_concat_axis) * p.concat_size;

    dst.data[offset_in_chunk + dst_base] = src.data[src_idx];
}