permute.comp
998 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
#version 450
// Workgroup width along x; consumed by the local_size_x layout qualifier below.
#define LOCAL_SZ_X 256
// Push constants read by main() through the instance name `p`.
layout(push_constant) uniform pushBlock {
// Exclusive upper bound on the output index: main() writes out_buffer[i] only for i < nthreads.
int nthreads;
// Number of entries of permute_order / new_stride that main() walks for each output element.
int num_axes;
// Amount added to the output index after each element an invocation handles.
// Presumably the total number of invocations along x in the dispatch -- TODO confirm against the host code.
int global_size;
} p;
// Source values; main() reads one element at the offset it computes from the strides.
layout(binding = 0) readonly buffer Input0{
float in_buffer[];
};
// For step j of the decomposition, permute_order[j] selects which old_stride entry to use.
layout(binding = 1) readonly buffer Input1{
int permute_order[];
};
// Multipliers used to build the offset into in_buffer (indexed by permute_order[j]).
layout(binding = 2) readonly buffer Input2{
int old_stride[];
};
// Divisors used to split the output index step by step (indexed by j directly).
// NOTE(review): entries are used as divisors, so they must be non-zero -- confirm the host guarantees this.
layout(binding = 3) readonly buffer Input3{
int new_stride[];
};
// Destination values; main() writes element i for each output index i it handles.
layout(binding = 4) writeonly buffer Output{
float out_buffer[];
};
// One-dimensional workgroup of LOCAL_SZ_X invocations.
layout(local_size_x = LOCAL_SZ_X, local_size_y = 1, local_size_z = 1) in;
// Copies in_buffer into out_buffer with the element order rearranged.
//
// For every output index handled by this invocation, the index is split step
// by step using new_stride[]; each quotient is multiplied by the old_stride[]
// entry selected through permute_order[], and the sum of those products is
// the offset read from in_buffer.
void main()
{
    // Start at this invocation's global x index and advance by p.global_size
    // until the p.nthreads bound is reached.
    int dst_index = int(gl_GlobalInvocationID.x);
    while (dst_index < p.nthreads)
    {
        int src_index = 0;
        int remainder = dst_index;

        for (int axis = 0; axis < p.num_axes; axis++)
        {
            int src_axis = permute_order[axis];
            int dst_stride = new_stride[axis];

            // Quotient for this step, re-weighted by the selected source stride.
            int coord = remainder / dst_stride;
            src_index += coord * old_stride[src_axis];

            // Keep only the part of the index not yet accounted for.
            remainder = remainder % dst_stride;
        }

        out_buffer[dst_index] = in_buffer[src_index];
        dst_index += p.global_size;
    }
}