#version 450
#define LOCAL_SZ_X 256

// Softmax / log-softmax along the channel axis.
//
// Each invocation handles one "outer" slice, selected by gl_GlobalInvocationID.x.
// Within a slice the element for channel c and inner position i is read at
//     gid * channels * channel_size + c * channel_size + i
// and the reduction runs over c independently for every i.
//
// max_buffer and sum_buffer are used as per-slice scratch storage and hold the
// per-position maximum and sum of exponentials when the shader finishes.

layout(binding = 0) readonly buffer buf0{
    float input_buffer[]; // outer_size * channels * channel_size
};
layout(binding = 1) buffer buf1{
    float max_buffer[]; // outer_size * channel_size
};
layout(binding = 2) buffer buf2{
    float sum_buffer[]; // outer_size * channel_size
};
layout(binding = 3) buffer buf3{
    float output_buffer[]; // outer_size * channels * channel_size
};
layout(push_constant) uniform pushBlock {
    int channel_size; // number of inner positions per channel
    int outer_size;   // number of slices; invocations with gid >= outer_size exit
    int channels;     // length of the reduced axis
    int logsoftmax;   // 1 selects log-softmax, any other value plain softmax
} p;

layout(local_size_x = LOCAL_SZ_X, local_size_y = 1, local_size_z = 1) in;

void main()
{
    int gid = int(gl_GlobalInvocationID.x);
    if (gid >= p.outer_size) return;

    int global_off = gid * p.channels * p.channel_size;
    int reduced_buffer_off = gid * p.channel_size;

    // Loop-invariant mode flag, evaluated once instead of per element.
    bool log_mode = (p.logsoftmax == 1);

    // find the max along channel: seed with channel 0, then fold in the rest
    int index = global_off;
    for (int i = 0; i < p.channel_size; ++i)
    {
        max_buffer[reduced_buffer_off + i] = input_buffer[index];
        index++;
    }
    for (int c = 1; c < p.channels; ++c)
    {
        for (int i = 0; i < p.channel_size; ++i)
        {
            max_buffer[reduced_buffer_off + i] = max(max_buffer[reduced_buffer_off + i], input_buffer[index]);
            index++;
        }
    }

    // subtract, exp and accumulate along channel
    for (int i = 0; i < p.channel_size; ++i)
        sum_buffer[reduced_buffer_off + i] = 0.f;

    index = global_off;
    for (int c = 0; c < p.channels; ++c)
    {
        for (int i = 0; i < p.channel_size; ++i)
        {
            float shifted = input_buffer[index] - max_buffer[reduced_buffer_off + i];
            float exp_val = exp(shifted);
            sum_buffer[reduced_buffer_off + i] += exp_val;
            // In log mode keep the shifted value rather than its exponential:
            // exp() can underflow to 0 for strongly negative inputs, and taking
            // log() of that afterwards has an undefined result in GLSL.
            output_buffer[index] = log_mode ? shifted : exp_val;
            index++;
        }
    }

    // normalise by the computed sum
    index = global_off;
    for (int c = 0; c < p.channels; ++c)
    {
        for (int i = 0; i < p.channel_size; ++i)
        {
            float denom = sum_buffer[reduced_buffer_off + i];
            float v;
            if (log_mode)
                v = output_buffer[index] - log(denom); // == log(exp(x - max) / sum)
            else
                v = output_buffer[index] / denom;
            output_buffer[index] = v;
            index++;
        }
    }
}