#version 450
// Compute shader that copies in_buffer into out_buffer with its axes
// reordered according to permute_order.
//
// For every output linear index i the shader splits i into per-axis
// coordinates using new_stride, and rebuilds the matching source linear
// index from those coordinates using old_stride[permute_order[axis]].

// Workgroup width along X; Y and Z are fixed to 1 in the layout below.
#define LOCAL_SZ_X 256

layout(push_constant) uniform pushBlock {
    int nthreads;     // exclusive upper bound on the output indices processed
    int num_axes;     // number of entries walked in the stride/order arrays
    int global_size;  // step between successive indices handled by one invocation
                      // NOTE(review): presumably the total number of dispatched
                      // invocations along X -- confirm against the host code.
} p;

layout(binding = 0) readonly buffer Input0 {
    float in_buffer[];       // source elements
};
layout(binding = 1) readonly buffer Input1 {
    int permute_order[];     // permute_order[j]: source axis used for output axis j
};
layout(binding = 2) readonly buffer Input2 {
    int old_stride[];        // per-axis strides applied to the source index
};
layout(binding = 3) readonly buffer Input3 {
    int new_stride[];        // per-axis strides used to decompose the output index
                             // NOTE(review): assumed non-zero for every axis, since
                             // they are used as divisors below -- confirm with caller.
};
layout(binding = 4) writeonly buffer Output {
    float out_buffer[];      // destination elements
};

layout(local_size_x = LOCAL_SZ_X, local_size_y = 1, local_size_z = 1) in;

void main()
{
    // Each invocation starts at its own global X id and advances by
    // p.global_size until it reaches p.nthreads.
    for (int i = int(gl_GlobalInvocationID.x); i < p.nthreads; i += p.global_size)
    {
        int old_pos = 0;  // linear index into in_buffer, accumulated axis by axis
        int new_pos = i;  // part of the output index not yet decomposed

        for (int j = 0; j < p.num_axes; ++j)
        {
            int order = permute_order[j];
            int stride = new_stride[j];  // read once, used for both / and %

            // Coordinate along output axis j, scaled by the source stride of
            // the axis it was taken from.
            old_pos += (new_pos / stride) * old_stride[order];
            // Keep the remainder for the following axes.
            new_pos %= stride;
        }

        out_buffer[i] = in_buffer[old_pos];
    }
}