// softmax.comp 2.13 KB

#version 450
// Number of invocations per workgroup along x (used by the local_size layout below).
#define LOCAL_SZ_X 256

// Source values; only read by this shader.
// Element (o, c, i) is addressed by main() as (o * channels + c) * channel_size + i.
layout(binding = 0) readonly buffer buf0{
    float input_buffer[]; // outer_size * channels * channel_size
};
// Per-position maximum over the channel axis; written, then read back, by main().
layout(binding = 1) buffer buf1{
    float max_buffer[]; // outer_size * channel_size
};
// Per-position sum of exp(input - max) over the channel axis; written, then read back, by main().
layout(binding = 2) buffer buf2{
    float sum_buffer[]; // outer_size * channel_size
};
// Result buffer, indexed the same way as input_buffer.
// main() first stores the un-normalized exponentials here, then overwrites them with the final values.
layout(binding = 3) buffer buf3{
    float output_buffer[]; // outer_size * channels * channel_size
};
// Shape and mode parameters supplied as push constants.
layout(push_constant) uniform pushBlock {
    int channel_size; // extent of the innermost (contiguous) index
    int outer_size;   // number of independent slices; one invocation handles one slice
    int channels;     // extent of the axis that main() reduces over
    int logsoftmax;   // main() takes the logarithm of the result when this equals 1
} p;
// One-dimensional workgroup of LOCAL_SZ_X invocations.
layout(local_size_x = LOCAL_SZ_X, local_size_y = 1, local_size_z = 1) in;

// Softmax along the channel axis, or log-softmax when p.logsoftmax == 1.
//
// Each invocation handles one outer index `gid` and walks its whole
// channels x channel_size slice sequentially in three passes:
//   1. per-position maximum over channels          -> max_buffer
//   2. exp(x - max) and its sum over channels      -> output_buffer, sum_buffer
//   3. normalization (or log-normalization)        -> output_buffer
//
// Subtracting the maximum keeps every exponent <= 0, so exp() cannot overflow.
void main()
{
    int gid = int(gl_GlobalInvocationID.x);
    // Invocations past the last outer index have nothing to do.
    if (gid >= p.outer_size) return;

    // Start of this invocation's slice in input_buffer / output_buffer.
    int global_off = gid * p.channels * p.channel_size;
    // Start of this invocation's row in max_buffer / sum_buffer.
    int reduced_buffer_off = gid * p.channel_size;

    // Pass 1: find the max along channel.
    // Seed the running max with channel 0, then fold in channels 1..channels-1.
    int index = global_off;
    for (int i = 0; i < p.channel_size; ++i)
    {
        max_buffer[reduced_buffer_off + i] = input_buffer[index];
        index++;
    }
    for (int c = 1; c < p.channels; ++c)
    {
        for (int i = 0; i < p.channel_size; ++i)
        {
            max_buffer[reduced_buffer_off + i] = max(max_buffer[reduced_buffer_off + i], input_buffer[index]);
            index++;
        }
    }

    // Pass 2: subtract the max, exponentiate and accumulate along channel.
    for (int i = 0; i < p.channel_size; ++i)
        sum_buffer[reduced_buffer_off + i] = 0.f;

    index = global_off;
    for (int c = 0; c < p.channels; ++c)
    {
        for (int i = 0; i < p.channel_size; ++i)
        {
            float exp_val = exp(input_buffer[index] - max_buffer[reduced_buffer_off + i]);
            output_buffer[index] = exp_val;
            sum_buffer[reduced_buffer_off + i] += exp_val;
            index++;
        }
    }

    // Pass 3: normalize by the computed sum.
    // The mode is the same for every element, so test it once.
    bool log_mode = (p.logsoftmax == 1);

    index = global_off;
    for (int c = 0; c < p.channels; ++c)
    {
        for (int i = 0; i < p.channel_size; ++i)
        {
            float v;
            if (log_mode)
            {
                // log(exp(x - max) / sum) == (x - max) - log(sum).
                // Evaluating the right-hand side avoids log(0) = -inf when
                // exp(x - max) underflows to zero for strongly negative inputs.
                v = input_buffer[index] - max_buffer[reduced_buffer_off + i]
                    - log(sum_buffer[reduced_buffer_off + i]);
            }
            else
            {
                v = output_buffer[index] / sum_buffer[reduced_buffer_off + i];
            }
            output_buffer[index] = v;
            index++;
        }
    }
}