// max_pool.comp
#version 450
#define LOCAL_SZ_X 256
// Pooling parameters passed as push constants and read through `p` in main().
// Member order and types define the push-constant layout, so they must stay in
// sync with whatever fills this block on the host side.
layout(push_constant) uniform pushBlock {
      int channels;   // channel count; used to split the flat output index into (n, c)
      int in_h;       // input plane height (rows)
      int in_w;       // input plane width (columns); also the input row pitch
      int out_h;      // output plane height
      int out_w;      // output plane width
      int padding_h;  // rows subtracted from the window origin (implicit top padding)
      int padding_w;  // columns subtracted from the window origin (implicit left padding)
      int filter_h;   // pooling window height
      int filter_w;   // pooling window width
      int stride_h;   // vertical step between consecutive windows
      int stride_w;   // horizontal step between consecutive windows
      int total;      // number of output elements; upper bound of the main() loop
      int need_mask;  // mask_buffer is written only when this equals 1
} p;

// Source data, read-only. main() addresses it as a flat array:
// (n * channels + c) * in_h * in_w + h * in_w + w.
layout(binding = 0) readonly buffer Input0{
    float in_buffer[];
};

// Pooled result, write-only. main() stores one value per flat output index
// in the range [0, p.total).
layout(binding = 1) writeonly buffer Output{
    float out_buffer[];
};

// Position of the selected maximum, write-only, written only when
// p.need_mask == 1. Each entry is h * in_w + w within the element's input
// plane, stored as a float; it is -1 when no input element was selected.
layout(binding = 2) writeonly buffer Mask{
    float mask_buffer[];
};

// One-dimensional workgroup: LOCAL_SZ_X invocations along x, 1 along y and z.
layout(local_size_x = LOCAL_SZ_X, local_size_y = 1, local_size_z = 1) in;

/*
 * Max pooling over a flat float buffer laid out in (n, c, h, w) order.
 *
 * Each invocation starts at its global x id and advances by the total number
 * of invocations in the dispatch until it passes p.total, so any dispatch
 * size covers every output element.
 *
 * For each output element the pooling window is clipped to the input plane,
 * scanned row by row, and the largest value is written to out_buffer. When
 * p.need_mask == 1 the in-plane position (h * in_w + w) of that value is
 * written to mask_buffer as a float. The comparison is strict, so the first
 * element encountered wins ties; if nothing is selected (e.g. the clipped
 * window is empty) the outputs keep their initial values: the -1.0 / 0.0
 * sentinel (intended as negative infinity) and position -1.
 */
void main()
{
    const int step = int(gl_WorkGroupSize.x * gl_NumWorkGroups.x);

    for (int index = int(gl_GlobalInvocationID.x); index < p.total; index += step)
    {
        // Peel the flat output index apart, fastest-varying dimension first.
        int rest = index;
        const int out_col = rest % p.out_w;
        rest /= p.out_w;
        const int out_row = rest % p.out_h;
        rest /= p.out_h;
        const int c_idx = rest % p.channels;
        const int n_idx = rest / p.channels;

        // Window origin in input coordinates; negative inside the padding.
        const int row_origin = out_row * p.stride_h - p.padding_h;
        const int col_origin = out_col * p.stride_w - p.padding_w;

        // Clip the window to the input plane. The end is derived from the
        // unclipped origin so padding shortens the window rather than shifting it.
        const int row_begin = max(row_origin, 0);
        const int col_begin = max(col_origin, 0);
        const int row_end = min(row_origin + p.filter_h, p.in_h);
        const int col_end = min(col_origin + p.filter_w, p.in_w);

        // Start of the (n, c) plane inside in_buffer.
        const int plane_base = (n_idx * p.channels + c_idx) * p.in_h * p.in_w;

        float best_val = -1.0 / 0.0;
        int best_pos = -1;

        for (int row = row_begin; row < row_end; ++row)
        {
            const int row_base = row * p.in_w;
            for (int col = col_begin; col < col_end; ++col)
            {
                const int pos = row_base + col;
                const float candidate = in_buffer[plane_base + pos];
                if (candidate > best_val)
                {
                    best_val = candidate;
                    best_pos = pos;
                }
            }
        }

        out_buffer[index] = best_val;
        if (p.need_mask == 1)
        {
            mask_buffer[index] = float(best_pos);
        }
    }
}