// lrn.comp 2.13 KB
#version 450
// Work-group width along X; consumed by the local_size_x layout declaration
// further down in this file.
#define LOCAL_SZ_X 256
// Per-dispatch parameters supplied by the host as push constants.
// Field descriptions below reflect how main() in this file uses each value.
layout(push_constant) uniform pushBlock {
    // Exclusive upper bound of the work-item index loop in main(); each index
    // is decomposed into one (b, y, x) position.
    int thread_num;
    // Number of channel planes walked for each position.
    int channels;
    // Plane dimensions; height * width is the element stride between channels.
    int height;
    int width;
    // Lag (in channels) after which a squared sample is removed from the
    // running sum. NOTE(review): presumably 2 * radius + 1 - confirm on host.
    int filter_len;
    // Number of leading channels accumulated before the first output; the
    // output channel trails the accumulation head by this amount.
    int radius;
    // Multiplier applied to the running sum of squares.
    float alpha;
    // Constant added to alpha * sum before exponentiation.
    float bias;
    // Exponent passed to pow(); the name suggests the host passes -beta.
    float negative_beta;
} p;

// Read-only source data at binding 0. main() addresses element (b, c, y, x)
// as ((b * channels + c) * height + y) * width + x.
layout(binding = 0) readonly buffer Input0{
    float in_buffer[];
};
// Write-only destination at binding 1. main() writes each result at the same
// flat index it reads the corresponding input element from.
layout(binding = 1) writeonly buffer Output{
    float dst_buffer[];
};
// One-dimensional work group: LOCAL_SZ_X invocations along X, 1 along Y and Z.
layout(local_size_x = LOCAL_SZ_X, local_size_y = 1, local_size_z = 1) in;
// Returns the square of the input element at flat index `idx`.
float squared_input(int idx)
{
  float v = in_buffer[idx];
  return v * v;
}

// Scales the input element at flat index `idx` by
// pow(bias + alpha * sum_sq, negative_beta) and stores it at the same index.
void store_scaled(int idx, float sum_sq)
{
  float scale = p.bias + sum_sq * p.alpha;
  dst_buffer[idx] = in_buffer[idx] * pow(scale, p.negative_beta);
}

// Entry point. Each invocation starts at its global X id and advances by the
// total number of invocations in X until thread_num is reached. Every index
// names one (b, y, x) position; for that position the channel planes are
// walked in order while a running sum of squared inputs is maintained, and
// each channel's output is the input scaled via store_scaled().
//
// The walk has three phases:
//   1. warm-up: accumulate the first min(radius, channels) channels, no output;
//   2. steady : add the newest channel, drop the one filter_len behind it,
//               and emit the channel `radius` behind the newest;
//   3. drain  : no more channels to add; keep dropping and emit the rest.
// NOTE(review): phases 2 and 3 line up only if filter_len == 2 * radius + 1;
// confirm against the host code that fills the push constants.
void main()
{
  int plane = p.height * p.width;  // element stride between channel planes
  int step = int(gl_NumWorkGroups.x * gl_WorkGroupSize.x);

  // Warm-up length: radius, clamped to the number of channels.
  int lead = p.channels;
  if (p.radius < p.channels) {
    lead = p.radius;
  }

  for (int item = int(gl_GlobalInvocationID.x); item < p.thread_num; item += step)
  {
    // Split the flat work-item index into column, row and batch.
    int col = item % p.width;
    int row = (item / p.width) % p.height;
    int batch = item / (p.width * p.height);
    // Flat index of channel 0 at this (batch, row, col) position.
    int base = batch * p.channels * p.height * p.width + row * p.width + col;

    float sum_sq = 0.0f;
    int c = 0;

    // Phase 1: warm-up.
    for (; c < lead; ++c) {
      sum_sq += squared_input(base + c * plane);
    }

    // Phase 2: steady state.
    for (; c < p.channels; ++c) {
      sum_sq += squared_input(base + c * plane);
      int expired = c - p.filter_len;
      if (expired >= 0) {
        sum_sq -= squared_input(base + expired * plane);
      }
      store_scaled(base + (c - p.radius) * plane, sum_sq);
    }

    // Phase 3: drain. The accumulation head stays `lead` channels ahead of
    // the channel being written, so it is derived rather than tracked.
    for (int out_c = c - lead; out_c >= 0 && out_c < p.channels; ++out_c) {
      int expired = out_c + lead - p.filter_len;
      if (expired >= 0) {
        sum_sq -= squared_input(base + expired * plane);
      }
      store_scaled(base + out_c * plane, sum_sq);
    }
  }
}