// permute.comp (998 bytes)
#version 450
#define LOCAL_SZ_X 256

// Push-constant parameters read by main() through the instance name `p`.
layout(push_constant) uniform pushBlock {
      // Exclusive upper bound on the output index: main() writes out_buffer[i]
      // only for i < nthreads.
      int nthreads;
      // Number of axes walked per element; bounds the index used into
      // permute_order and new_stride.
      int num_axes;
      // Step added to the output index after each element is written.
      // NOTE(review): presumably the total number of invocations dispatched
      // along x -- confirm against the host-side dispatch code.
      int global_size;
} p;

// Source values; main() reads one element at the computed input offset.
layout(binding = 0) readonly buffer Input0{
    float in_buffer[];
};
// Entry j is used as the index into old_stride while processing axis j.
// NOTE(review): presumably the input axis that output axis j comes from --
// confirm against the host code that fills this buffer.
layout(binding = 1) readonly buffer Input1{
    int permute_order[];
};
// Per-axis multipliers accumulated into the input offset.
// NOTE(review): presumably the element strides of the input layout -- confirm.
layout(binding = 2) readonly buffer Input2{
    int old_stride[];
};
// Per-axis divisors applied to the output index (quotient, then remainder).
// NOTE(review): presumably the element strides of the output layout; every
// entry used must be non-zero because it is a divisor -- confirm.
layout(binding = 3) readonly buffer Input3{
    int new_stride[];
};
// Destination values; main() writes element i for each output index i it handles.
layout(binding = 4) writeonly buffer Output{
    float out_buffer[];
};

// Work-group size: LOCAL_SZ_X invocations along x, one along y and z.
layout(local_size_x = LOCAL_SZ_X, local_size_y = 1, local_size_z = 1) in;

// Copies in_buffer into out_buffer with its axes rearranged.
//
// Each invocation starts at its own x global invocation ID and advances by
// p.global_size until it reaches p.nthreads. For every output index it peels
// off one coordinate per axis using new_stride, then rebuilds the matching
// input offset from old_stride looked up through permute_order.
void main()
{
    int dst_index = int(gl_GlobalInvocationID.x);

    while (dst_index < p.nthreads)
    {
        int remaining = dst_index;  // part of the output index not yet decomposed
        int src_index = 0;          // input offset accumulated axis by axis

        for (int axis = 0; axis < p.num_axes; axis++)
        {
            int dst_stride = new_stride[axis];
            int src_axis = permute_order[axis];

            // The quotient is the coordinate along this output axis; the
            // remainder is carried on to the next axis.
            int coord = remaining / dst_stride;
            src_index += coord * old_stride[src_axis];
            remaining = remaining % dst_stride;
        }

        out_buffer[dst_index] = in_buffer[src_index];
        dst_index += p.global_size;
    }
}