Neural Networks: Deformable Convolution (deformable conv) Source Code Analysis

Note: the source code analyzed in this article comes from mmcv (OpenMMLab).

The deformable convolution implementation mainly involves three files:

  • deform_conv.cpp: located at /mmcv/ops/csrc/pytorch/deform_conv.cpp
  • deform_conv_cuda.cu: located at /mmcv/ops/csrc/pytorch/deform_conv_cuda.cu
  • deform_conv_cuda_kernel.cuh: located at /mmcv/ops/csrc/deform_conv_cuda_kernel.cuh

Forward pass

First, look at the function deform_conv_forward() in deform_conv.cpp.

Its main arguments are the input feature map input, the convolution weight weight, the offset offset, and the output feature map output.

The function dispatches to deform_conv_forward_cuda():

void deform_conv_forward(Tensor input, Tensor weight, Tensor offset,
                         Tensor output, Tensor columns, Tensor ones, int kW,
                         int kH, int dW, int dH, int padW, int padH,
                         int dilationW, int dilationH, int group,
                         int deformable_group, int im2col_step) {
  if (input.device().is_cuda()) {
#ifdef MMCV_WITH_CUDA
    CHECK_CUDA_INPUT(input);
    CHECK_CUDA_INPUT(offset);
    CHECK_CUDA_INPUT(weight);
    CHECK_CUDA_INPUT(output);
    CHECK_CUDA_INPUT(columns);
    CHECK_CUDA_INPUT(ones);

    deform_conv_forward_cuda(input, weight, offset, output, columns, ones, kW,
                             kH, dW, dH, padW, padH, dilationW, dilationH,
                             group, deformable_group, im2col_step);
#else
    AT_ERROR("DeformConv is not compiled with GPU support");
#endif
  } else {
    AT_ERROR("DeformConv is not implemented on CPU");
  }
}

deform_conv_forward_cuda() in turn calls DeformConvForwardCUDAKernelLauncher(), which is located in deform_conv_cuda.cu:

void deform_conv_forward_cuda(Tensor input, Tensor weight, Tensor offset,
                              Tensor output, Tensor columns, Tensor ones,
                              int kW, int kH, int dW, int dH, int padW,
                              int padH, int dilationW, int dilationH, int group,
                              int deformable_group, int im2col_step) {
  DeformConvForwardCUDAKernelLauncher(
      input, weight, offset, output, columns, ones, kW, kH, dW, dH, padW, padH,
      dilationW, dilationH, group, deformable_group, im2col_step);
}

This is the core function of deformable convolution. A regular convolution convolves each point of the feature map with its fixed neighborhood; deformable convolution additionally takes an offset, re-samples new neighborhood points according to that offset (instead of the directly adjacent points), and then convolves. Deformable convolution therefore breaks down into two steps:

1) gather the new neighborhood features according to the offset;

2) perform the convolution.

The first step, gathering the neighborhood features, is the relatively tricky part. The main idea is to locate the new neighborhood points according to the offset and stack them onto the center point's feature, so that the channel count grows from C to C*kh*kw; this is done by the deformable_im2col() function in the code below.

The second step is then simple: it is equivalent to a 1x1 convolution with a weight of shape (C2, C1*kh*kw, 1, 1), i.e. a plain matrix multiplication, as shown in the sketch after this paragraph.
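
As a minimal sketch of step 2 (the shapes and names here are illustrative, not taken from mmcv): once deformable_im2col() has produced columns of shape (C1*kh*kw, h*w), the convolution collapses into one matrix product with the flattened weight, assuming group = 1.

#include <torch/torch.h>

// Illustrative shapes only: C1 = 4 input channels, C2 = 8 output channels,
// a 3x3 kernel and an output plane of 5x5 = 25 positions.
int main() {
  auto weight  = torch::randn({8, 4, 3, 3});     // (C2, C1, kh, kw)
  auto columns = torch::randn({4 * 3 * 3, 25});  // (C1*kh*kw, h*w), as built by im2col

  // Step 2 of deformable conv: flatten the weight to (C2, C1*kh*kw), multiply
  // with columns -> (C2, h*w), then reshape back to (C2, h, w).
  auto output = weight.flatten(1).mm(columns).view({8, 5, 5});
  return 0;
}

This is exactly the role of the addmm_() call in the launcher below, just without the im2col_step and group bookkeeping.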

void DeformConvForwardCUDAKernelLauncher(Tensor input, Tensor weight,  // input feature map input: (B, C1, H, W); conv weight weight: (C2, C1/group, kh, kw)
                                         Tensor offset, Tensor output,  // coordinate offset offset: (B, deform_group*2*kh*kw, h, w); output feature map output: (B, C2, h, w)
                                         Tensor columns, Tensor ones, int kW, // kW, kH: kernel size
                                         int kH, int dW, int dH, int padW, // dW, dH: convolution stride
                                         int padH, int dilationW, int dilationH,
                                         int group, int deformable_group, // usually every channel shares one offset map; several channels may also share one, in which case the channels are split into deformable_group groups
                                         int im2col_step) {  // step
  // todo: resize columns to include im2col: done
  // todo: add im2col_step as input
  // todo: add new output buffer and transpose it to output (or directly
  // transpose output) todo: possibly change data indexing because of
  // parallel_imgs

  deform_conv_shape_check(input, offset, NULL, weight, kH, kW, dH, dW, padH,
                          padW, dilationH, dilationW, group, deformable_group);
  at::DeviceGuard guard(input.device());

  int batch = 1;
  if (input.ndimension() == 3) {
    // Force batch
    batch = 0;
    input.unsqueeze_(0);
    offset.unsqueeze_(0);
  }

  // todo: assert batchsize dividable by im2col_step

  long batchSize = input.size(0); // B
  long nInputPlane = input.size(1); // input channels C1
  long inputHeight = input.size(2); // input height H
  long inputWidth = input.size(3); // input width W

  long nOutputPlane = weight.size(0); // output channels C2

  long outputWidth =
      (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; // output width w
  long outputHeight =
      (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; // output height h

  TORCH_CHECK((offset.size(0) == batchSize), "invalid batch size of offset");

  output = output.view({batchSize / im2col_step, im2col_step, nOutputPlane,
                        outputHeight, outputWidth}); // (B, C2, h, w) -> (B/step, step, C2, h, w)
  columns = at::zeros({nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
      input.options());  // (C1*kw*kh, step*h*w); C1*kw*kh can be read as every point of the input feature map gathering its neighborhood features: kw*kh neighbors, each with C1 channels, hence C1*kw*kh in total

  if (ones.ndimension() != 2 ||
      ones.size(0) * ones.size(1) < outputHeight * outputWidth) {
    ones = at::ones({outputHeight, outputWidth}, input.options()); // (h, w)
  }

  input = input.view({batchSize / im2col_step, im2col_step, nInputPlane,
                      inputHeight, inputWidth}); // (B, C1, H, W) -> (B/step, step, C1, H, W)
  offset = offset.view({batchSize / im2col_step, im2col_step,
                   deformable_group * 2 * kH * kW, outputHeight, outputWidth});  // (B, deform_group*2*kh*kw, h, w) -> (B/step, step, deform_group*2*kh*kw, h, w)

  Tensor output_buffer = at::zeros({batchSize / im2col_step, nOutputPlane,
                                    im2col_step * outputHeight, outputWidth},
                                   output.options()); // (B/step, C2, step*h, w), all zeros

  output_buffer = output_buffer.view({output_buffer.size(0), group, output_buffer.size(1) / group,
       output_buffer.size(2), output_buffer.size(3)}); // (B/step, C2, step*h, w) -> (B/step, group, C2/group, step*h, w)

  // split the batch into B/step chunks and process them one by one
  for (int elt = 0; elt < batchSize / im2col_step; elt++) {
    // analyzed in detail below; it fills columns, which can be understood as every point of the input feature map gathering its (offset-shifted) neighborhood features into its own column
    deformable_im2col(input[elt], offset[elt], nInputPlane, inputHeight,  // input[elt]: (step, C1, H, W), offset[elt]: (step, deform_group*2*kw*kh, h, w)
                      inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
                      dilationW, im2col_step, deformable_group, columns); // columns: (C1*kw*kh, step*h*w)

    columns = columns.view({group, columns.size(0) / group, columns.size(1)});  // (C1*kw*kh, step*h*w) -> (group, C1*kh*kw/group, step*h*w)
    weight = weight.view({group, weight.size(0) / group, weight.size(1),
                          weight.size(2), weight.size(3)});  // (C2, C1/group, kh, kw) -> (group, C2/group, C1/group, kh, kw)

    // split into `group` groups and process them one by one
    for (int g = 0; g < group; g++) {
      // columns holds the gathered features; multiply with weight to get the output feature output_buffer
      // addmm_() adds the matrix product to the tensor in place
      // flatten(1) flattens dimension 1 and everything after it
      output_buffer[elt][g] = output_buffer[elt][g]
                                  .flatten(1) // (C2/group, step*h, w) -> (C2/group, step*h*w)
                                  .addmm_(weight[g].flatten(1), columns[g])  // (C2/group, C1/group*kh*kw) mm (C1/group*kh*kw, step*h*w) -> (C2/group, step*h*w)
                                  .view_as(output_buffer[elt][g]);  // (C2/group, step*h*w) -> (C2/group, step*h, w)
    }
    columns =
        columns.view({columns.size(0) * columns.size(1), columns.size(2)}); // (group, C1*kh*kw/group, step*h*w) -> (C1*kw*kh, step*h*w)
  }

  output_buffer = output_buffer.view({output_buffer.size(0), output_buffer.size(1) * output_buffer.size(2),
       output_buffer.size(3), output_buffer.size(4)});  // (B/step, group, C2/group, step*h, w) -> (B/step, C2, step*h, w)

  output_buffer = output_buffer.view({batchSize / im2col_step, nOutputPlane,
                                      im2col_step, outputHeight, outputWidth});  // (B/step, C2, step*h, w) -> (B/step, C2, step, h, w)
  output_buffer.transpose_(1, 2);  // (B/step, C2, step, h, w) -> (B/step, step, C2, h, w)
  output.copy_(output_buffer); // copy output_buffer into output
  output = output.view({batchSize, nOutputPlane, outputHeight, outputWidth});  // (B/step, step, C2, h, w) -> (B, C2, h, w)

  input = input.view({batchSize, nInputPlane, inputHeight, inputWidth}); // (B/step, step, C1, H, W) -> (B, C1, H, W)
  offset = offset.view({batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});  // (B/step, step, deform_group*2*kh*kw, h, w) -> (B, deform_group*2*kh*kw, h, w)

  if (batch == 0) {
    output = output.view({nOutputPlane, outputHeight, outputWidth});
    input = input.view({nInputPlane, inputHeight, inputWidth});
    offset = offset.view({offset.size(1), offset.size(2), offset.size(3)});
  }
}

The deformable_im2col() function called above is the key code of the deformable convolution forward pass; it implements the gathering of neighborhood features.

void deformable_im2col(Tensor data_im, Tensor data_offset, const int channels, // input feature data_im: (step, C1, H, W); offset data_offset: (step, deform_group*2*kw*kh, h, w); channels: C1
                       const int height, const int width, const int ksize_h, // height = H, width = W
                       const int ksize_w, const int pad_h, const int pad_w,
                       const int stride_h, const int stride_w,
                       const int dilation_h, const int dilation_w,
                       const int parallel_imgs, const int deformable_group, // parallel_imgs = step
                       Tensor data_col) { // data_col: (C1*kw*kh, step*h*w)
  // num_axes should be smaller than block size
  // todo: check parallel_imgs is correctly passed in
  // compute the height and width of columns
  int height_col =
      (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1; // h
  int width_col =
      (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1; // w
  int num_kernels = channels * height_col * width_col * parallel_imgs; // C1*h*w*step
  int channel_per_deformable_group = channels / deformable_group; // C1/deform_group, i.e. how many channels each deformable group covers

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      data_im.scalar_type(), "deformable_im2col_gpu", ([&] { // [&] lambda: outer variables are captured by reference
        const scalar_t *data_im_ = data_im.data_ptr<scalar_t>();  // raw data pointers
        const scalar_t *data_offset_ = data_offset.data_ptr<scalar_t>();
        scalar_t *data_col_ = data_col.data_ptr<scalar_t>();

        // CUDA kernel, analyzed below
        deformable_im2col_gpu_kernel<<<GET_BLOCKS(num_kernels),
                                       THREADS_PER_BLOCK, 0,
                                       at::cuda::getCurrentCUDAStream()>>>(
            num_kernels, data_im_, data_offset_, height, width, ksize_h,
            ksize_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
            channel_per_deformable_group, parallel_imgs, channels,
            deformable_group, height_col, width_col, data_col_);
      }));
  AT_CUDA_CHECK(cudaGetLastError());
}
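
GET_BLOCKS and THREADS_PER_BLOCK come from mmcv's CUDA helper header, and the kernel below iterates with CUDA_1D_KERNEL_LOOP. The following is only a sketch of this standard one-thread-per-element launch pattern; the helper names and constants here are assumptions, not mmcv's exact definitions.

// Hypothetical launch helpers: one thread per element of columns, with a
// grid-stride loop as the fallback when n exceeds the number of threads.
constexpr int kThreadsPerBlock = 512;          // assumed block size
inline int GetBlocks(int n) {                  // ceil(n / kThreadsPerBlock)
  return (n + kThreadsPerBlock - 1) / kThreadsPerBlock;
}

// CUDA_1D_KERNEL_LOOP(index, n) expands to a grid-stride loop of roughly this form:
// for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < n;
//      index += blockDim.x * gridDim.x) { ... }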

The function deformable_im2col_gpu_kernel() lives in deform_conv_cuda_kernel.cuh.

template <typename T>
__global__ void deformable_im2col_gpu_kernel(
    const int n, const T *data_im, const T *data_offset, const int height,  // n = C1*h*w*step; data_im: start address of the input feature (step, C1, H, W); data_offset: start address of the coordinate offsets (step, deform_group*2*kw*kh, h, w); height = H
    const int width, const int kernel_h, const int kernel_w, const int pad_h, // input width width = W
    const int pad_w, const int stride_h, const int stride_w,
    const int dilation_h, const int dilation_w,
    const int channel_per_deformable_group, const int batch_size, // batch_size = step
    const int num_channels, const int deformable_group, const int height_col, // num_channels = C1; output height height_col = h
    const int width_col, T *data_col) { // width_col = w; data_col: start address of columns (C1*kw*kh, step*h*w)
  CUDA_1D_KERNEL_LOOP(index, n) {  // index iterates from 0 to n
    // index index of output matrix
    // index runs over 0 .. C1*step*h*w and can be read as the index of a point in columns, except that it lacks the kh*kw dimension, so one index corresponds to one kh*kw patch
    const int w_col = index % width_col;  // w coordinate of the point
    const int h_col = (index / width_col) % height_col;  // h coordinate of the point
    const int b_col = (index / width_col / height_col) % batch_size;  // which sample in the batch the point belongs to
    const int c_im = (index / width_col / height_col) / batch_size; // channel of the corresponding point on the input feature data_im; points in columns and points in data_im are in one-to-one correspondence
    const int c_col = c_im * kernel_h * kernel_w;  // channel of the point in columns, since columns has C1*kw*kh channels while data_im has C1

    // compute deformable group index
    const int deformable_group_index = c_im / channel_per_deformable_group;

    const int h_in = h_col * stride_h - pad_h;  // height of this output point mapped back onto the input feature map
    const int w_in = w_col * stride_w - pad_w;  // width of this output point mapped back onto the input feature map
    // address of this point in columns
    T *data_col_ptr =
        data_col +
        ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col;
    // start address of the corresponding channel plane of data_im; only b_col and c_im are accounted for, the point's height/width is not included yet
    const T *data_im_ptr =
        data_im + (b_col * num_channels + c_im) * height * width;
    // only b_col and deformable_group_index are accounted for; the point's coordinate and offset index are not included yet
    const T *data_offset_ptr =
        data_offset + (b_col * deformable_group + deformable_group_index) * 2 *
                          kernel_h * kernel_w * height_col * width_col;
    // one index corresponds to one kh*kw patch, so loop over kh*kw
    for (int i = 0; i < kernel_h; ++i) {
      for (int j = 0; j < kernel_w; ++j) {
        const int data_offset_h_ptr =
            ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col;  // full address of the h offset
        const int data_offset_w_ptr =
            ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col; // full address of the w offset
        const T offset_h = data_offset_ptr[data_offset_h_ptr]; // h offset
        const T offset_w = data_offset_ptr[data_offset_w_ptr]; // w offset
        T val = static_cast<T>(0);
        const T h_im = h_in + i * dilation_h + offset_h;  // offset-shifted height of this neighborhood point on the input feature map
        const T w_im = w_in + j * dilation_w + offset_w;  // offset-shifted width of this neighborhood point on the input feature map
        if (h_im > -1 && w_im > -1 && h_im < height && w_im < width)
          // h_im and w_im are fractional, so the value must be interpolated; analyzed below
          val = deformable_im2col_bilinear(data_im_ptr, width, height, width,
                                           h_im, w_im);
        *data_col_ptr = val;
        data_col_ptr += batch_size * height_col * width_col;  // columns has shape (C1*kw*kh, step*h*w) with kw and kh in front of step*h*w, so consecutive sampling points are step*h*w elements apart
      }
    }
  }
}
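
For clarity, the index decomposition at the top of the kernel is plain mixed-radix arithmetic on the flat index; a host-side sketch (the struct and function names are made up for illustration):

#include <cstdio>

// Decompose a flat columns index (0 .. C1*step*h*w - 1) into
// (c_im, b_col, h_col, w_col), mirroring the arithmetic in the kernel.
struct ColIndex { int c_im, b_col, h_col, w_col; };

ColIndex decompose(int index, int width_col, int height_col, int batch_size) {
  ColIndex r;
  r.w_col = index % width_col;
  r.h_col = (index / width_col) % height_col;
  r.b_col = (index / width_col / height_col) % batch_size;
  r.c_im  = (index / width_col / height_col) / batch_size;
  return r;
}

int main() {
  // e.g. width_col = 4, height_col = 3, step = 2: index 30 -> c_im = 1, b_col = 0, h_col = 1, w_col = 2
  ColIndex r = decompose(30, 4, 3, 2);
  printf("c_im=%d b_col=%d h_col=%d w_col=%d\n", r.c_im, r.b_col, r.h_col, r.w_col);
  return 0;
}

The fastest-varying dimensions are w, then h, then the batch index, which matches the (C1*kw*kh, step*h*w) layout of columns.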

The function deformable_im2col_bilinear() fetches the feature value at the offset-shifted point via bilinear interpolation; it is straightforward.

template <typename T>
__device__ T deformable_im2col_bilinear(const T *input, const int data_width,  // input: start address of the current channel plane of the input feature map
                                        const int height, const int width, T h,
                                        T w) {
  if (h <= -1 || height <= h || w <= -1 || width <= w) {
    return 0;
  }

  // floor and ceil of the fractional coordinates
  int h_low = floor(h);
  int w_low = floor(w);
  int h_high = h_low + 1;
  int w_high = w_low + 1;

  T lh = h - h_low;
  T lw = w - w_low;
  T hh = 1 - lh, hw = 1 - lw;

  T v1 = 0;
  if (h_low >= 0 && w_low >= 0) v1 = input[h_low * data_width + w_low];
  T v2 = 0;
  if (h_low >= 0 && w_high <= width - 1)
    v2 = input[h_low * data_width + w_high];
  T v3 = 0;
  if (h_high <= height - 1 && w_low >= 0)
    v3 = input[h_high * data_width + w_low];
  T v4 = 0;
  if (h_high <= height - 1 && w_high <= width - 1)
    v4 = input[h_high * data_width + w_high];

  T w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;

  T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
  return val;
}
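
Written out, with lh = h - floor(h), lw = w - floor(w), and v1..v4 the four neighboring pixels (top-left, top-right, bottom-left, bottom-right), the function computes the standard bilinear weighted average, whose four weights sum to 1:

val = (1 - lh) * (1 - lw) * v1 + (1 - lh) * lw * v2 + lh * (1 - lw) * v3 + lh * lw * v4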

This concludes the analysis of the forward pass; gradient backpropagation is analyzed next.

Backward pass

Gradients need to be computed for the input feature Input, the offset offset, and the parameter weight weight.

Gradient computation for the input feature Input and the offset offset

Backpropagation starts from gradOutput, the gradient of the loss with respect to the forward output, and propagates it backwards.
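
Since each group of the forward pass is just a matrix product, output = weight.flatten(1) mm columns, the backward pass follows the standard matrix-product rule. The gradient flowing into columns is

grad_columns = weight.flatten(1)^T mm grad_output

which is exactly what the addmm_() call inside the loop below computes; grad_columns is then scattered back into gradInput by deformable_col2im() and into gradOffset by deformable_col2im_coord(). (The weight gradient is handled by a separate backward-parameters path that this article does not cover.)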

Look at the function deform_conv_backward_input() in deform_conv.cpp.

It dispatches to deform_conv_backward_input_cuda():

void deform_conv_backward_input(Tensor input,  // (B, C1, H, W)
                                Tensor offset, // 
                                Tensor gradOutput,
                                Tensor gradInput, Tensor gradOffset,
                                Tensor weight, Tensor columns, int kW, int kH,
                                int dW, int dH, int padW, int padH,
                                int dilationW, int dilationH, int group,
                                int deformable_group, int im2col_step) {
  if (input.device().is_cuda()) {
#ifdef MMCV_WITH_CUDA
    CHECK_CUDA_INPUT(input);
    CHECK_CUDA_INPUT(offset);
    CHECK_CUDA_INPUT(gradOutput);
    CHECK_CUDA_INPUT(gradInput);
    CHECK_CUDA_INPUT(gradOffset);
    CHECK_CUDA_INPUT(weight);
    CHECK_CUDA_INPUT(columns);

    deform_conv_backward_input_cuda(input, offset, gradOutput, gradInput,
                                    gradOffset, weight, columns, kW, kH, dW, dH,
                                    padW, padH, dilationW, dilationH, group,
                                    deformable_group, im2col_step);
#else
    AT_ERROR("DeformConv is not compiled with GPU support");
#endif
  } else {
    AT_ERROR("DeformConv is not implemented on CPU");
  }
}

The function deform_conv_backward_input_cuda() in turn calls DeformConvBackwardInputCUDAKernelLauncher():

void deform_conv_backward_input_cuda(Tensor input, Tensor offset,
                                     Tensor gradOutput, Tensor gradInput,
                                     Tensor gradOffset, Tensor weight,
                                     Tensor columns, int kW, int kH, int dW,
                                     int dH, int padW, int padH, int dilationW,
                                     int dilationH, int group,
                                     int deformable_group, int im2col_step) {
  DeformConvBackwardInputCUDAKernelLauncher(
      input, offset, gradOutput, gradInput, gradOffset, weight, columns, kW, kH,
      dW, dH, padW, padH, dilationW, dilationH, group, deformable_group,
      im2col_step);
}

DeformConvBackwardInputCUDAKernelLauncher() is located in deform_conv_cuda.cu; this is the core code.

void DeformConvBackwardInputCUDAKernelLauncher(Tensor input,  // (B, C1, H, W)
    Tensor offset, // (B, deform_group*2*kw*kh, h, w)
    Tensor gradOutput, // (B, C2, h, w)
    Tensor gradInput,  // (B, C1, H, W)
    Tensor gradOffset, // (B, deform_group*2*kw*kh, h, w)
    Tensor weight,  // (C2, C1/group, kh, kw)
    Tensor columns,
    int kW, int kH, int dW,
    int dH, int padW, int padH, int dilationW, int dilationH, int group,
    int deformable_group, int im2col_step) {
  deform_conv_shape_check(input, offset, &gradOutput, weight, kH, kW, dH, dW,
                          padH, padW, dilationH, dilationW, group,
                          deformable_group);
  at::DeviceGuard guard(input.device());

  int batch = 1;

  if (input.ndimension() == 3) {
    // Force batch
    batch = 0;
    input = input.view({1, input.size(0), input.size(1), input.size(2)});
    offset = offset.view({1, offset.size(0), offset.size(1), offset.size(2)});
    gradOutput = gradOutput.view({1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)});
  }

  long batchSize = input.size(0);  // B
  long nInputPlane = input.size(1); // C1
  long inputHeight = input.size(2); // H
  long inputWidth = input.size(3); // W

  long nOutputPlane = weight.size(0); // C2

  long outputWidth =
      (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;  // output width w
  long outputHeight =
      (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; // output height h

  TORCH_CHECK((offset.size(0) == batchSize), 3, "invalid batch size of offset");
  gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth});  // (B, C1, H, W)
  columns = at::zeros({nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
      input.options());  // (C1*kw*kh, step*h*w), all zeros

  // change order of grad output
  gradOutput = gradOutput.view({batchSize / im2col_step, im2col_step,
                                nOutputPlane, outputHeight, outputWidth}); // (B, C2, h, w) -> (B/step, step, C2, h, w)
  gradOutput.transpose_(1, 2);  // (B/step, step, C2, h, w) -> (B/step, C2, step, h, w)

  gradInput = gradInput.view({batchSize / im2col_step, im2col_step, nInputPlane,
                              inputHeight, inputWidth});  // (B, C1, H, W) -> (B/step, step, C1, H, W)
  input = input.view({batchSize / im2col_step, im2col_step, nInputPlane,
                      inputHeight, inputWidth});  // (B, C1, H, W) -> (B/step, step, C1, H, W)
  gradOffset = gradOffset.view({batchSize / im2col_step, im2col_step,
                                deformable_group * 2 * kH * kW, outputHeight,
                                outputWidth});  // (B, deform_group*2*kw*kh, h, w) -> (B/step, step, deform_group*2*kh*kw, h, w)
  offset =
      offset.view({batchSize / im2col_step, im2col_step,
                   deformable_group * 2 * kH * kW, outputHeight, outputWidth});  // (B, deform_group*2*kw*kh, h, w) -> (B/step, step, deform_group*2*kh*kw, h, w)

  for (int elt = 0; elt < batchSize / im2col_step; elt++) {
    // divide into groups
    columns = columns.view({group, columns.size(0) / group, columns.size(1)}); // (C1*kw*kh, step*h*w) -> (group, C1*kw*kh/group, step*h*w)
    weight = weight.view({group, weight.size(0) / group, weight.size(1),
                          weight.size(2), weight.size(3)});  // (C2, C1/group, kh, kw) -> (group, C2/group, C1/group, kh, kw)
    gradOutput = gradOutput.view({gradOutput.size(0), group, gradOutput.size(1) / group,
         gradOutput.size(2), gradOutput.size(3), gradOutput.size(4)});  // (B/step, C2, step, h, w) -> (B/step, group, C2/group, step, h, w)

    for (int g = 0; g < group; g++) {
      // compute the gradient of columns from gradOutput: the gradient of a matrix product
      columns[g] = columns[g].addmm_(
          weight[g].flatten(1).transpose(0, 1),  // (C2/group, C1/group, kh, kw) -> (C2/group, C1/group*kh*kw) -> (C1/group*kh*kw, C2/group)
          gradOutput[elt][g].flatten(1), 0.0f, 1.0f);  // (C2/group, step, h, w) -> (C2/group, step*h*w)
    }

    columns =
        columns.view({columns.size(0) * columns.size(1), columns.size(2)}); // (group, C1*kw*kh/group, step*h*w) -> (C1*kh*kw, step*h*w)
    gradOutput = gradOutput.view({gradOutput.size(0), gradOutput.size(1) * gradOutput.size(2),
         gradOutput.size(3), gradOutput.size(4), gradOutput.size(5)}); // (B/step, group, C2/group, step, h, w) -> (B/step, C2, step, h, w)

    // compute the offset gradient gradOffset; analyzed below
    deformable_col2im_coord(columns, input[elt], offset[elt], nInputPlane,  // columns: (C1*kh*kw, step*h*w), input[elt]: (step, C1, H, W), offset[elt]: (step, deform_group*2*kw*kh, h, w), nInputPlane: C1
                            inputHeight, inputWidth, kH, kW, padH, padW, dH, dW,
                            dilationH, dilationW, im2col_step, deformable_group,
                            gradOffset[elt]); // gradOffset[elt]: (step, deform_group*2*kw*kh, h, w)

    // compute the input-feature gradient gradInput
    deformable_col2im(columns, offset[elt], nInputPlane, inputHeight,   // columns: (C1*kh*kw, step*h*w), offset[elt]: (step, deform_group*2*kw*kh, h, w), nInputPlane: C1
                      inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
                      dilationW, im2col_step, deformable_group, gradInput[elt]);  // (step, C1, H, W)
  }

  gradOutput.transpose_(1, 2);
  gradOutput =
      gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth});

  gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth});
  input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
  gradOffset = gradOffset.view({batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
  offset = offset.view({batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});

  if (batch == 0) {
    gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth});
    input = input.view({nInputPlane, inputHeight, inputWidth});
    gradInput = gradInput.view({nInputPlane, inputHeight, inputWidth});
    offset = offset.view({offset.size(1), offset.size(2), offset.size(3)});
    gradOffset =
        gradOffset.view({offset.size(1), offset.size(2), offset.size(3)});
  }
}
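
The two col2im calls at the end of the loop undo the gather of the forward pass. Conceptually, deformable_col2im() (not shown in this article) scatter-adds every entry of grad_columns back onto the four input pixels surrounding its sampling location, reusing the same bilinear weights w1..w4 as the forward pass; since many sampling points can land near the same pixel, the accumulation is typically done with atomicAdd:

grad_input[v_k] += w_k * grad_columns_entry,  for k = 1..4

deformable_col2im_coord() instead accumulates grad_offset by weighting grad_columns with the derivative of the bilinear interpolation with respect to the sampled coordinate, which is analyzed next.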

The function deformable_col2im_coord() computes the offset gradient gradOffset; it launches the deformable_col2im_coord_gpu_kernel() CUDA kernel.

void deformable_col2im_coord(Tensor data_col, Tensor data_im, Tensor data_offset, const int channels, // data_col: (C1*kh*kw, step*h*w), data_im: (step, C1, H, W), data_offset: (step, deform_group*2*kw*kh, h, w), channels: C1
    const int height, const int width, const int ksize_h, const int ksize_w,  // height = H
    const int pad_h, const int pad_w, const int stride_h, const int stride_w,
    const int dilation_h, const int dilation_w, const int parallel_imgs, // parallel_imgs = step
    const int deformable_group, Tensor grad_offset) { // grad_offset: (step, deform_group*2*kw*kh, h, w)
  int height_col =
      (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1;  // h
  int width_col =
      (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1; // w
  int num_kernels = height_col * width_col * 2 * ksize_h * ksize_w *
                    deformable_group * parallel_imgs;  // h*w*2*kh*kw*deform_group*step
  int channel_per_deformable_group =
      channels * ksize_h * ksize_w / deformable_group;  // C1*kh*kw/deform_group

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      data_col.scalar_type(), "deformable_col2im_coord_gpu", ([&] {
        const scalar_t *data_col_ = data_col.data_ptr<scalar_t>();
        const scalar_t *data_im_ = data_im.data_ptr<scalar_t>();
        const scalar_t *data_offset_ = data_offset.data_ptr<scalar_t>();
        scalar_t *grad_offset_ = grad_offset.data_ptr<scalar_t>();

        deformable_col2im_coord_gpu_kernel<<<
            GET_BLOCKS(num_kernels), THREADS_PER_BLOCK, 0,
            at::cuda::getCurrentCUDAStream()>>>(
            num_kernels, data_col_, data_im_, data_offset_, channels, height,
            width, ksize_h, ksize_w, pad_h, pad_w, stride_h, stride_w,
            dilation_h, dilation_w, channel_per_deformable_group, parallel_imgs,
            2 * ksize_h * ksize_w * deformable_group, deformable_group,
            height_col, width_col, grad_offset_);
      }));
  AT_CUDA_CHECK(cudaGetLastError());
}

The deformable_col2im_coord_gpu_kernel() kernel:

template <typename T>
__global__ void deformable_col2im_coord_gpu_kernel(
    const int n,  // h*w*2*kh*kw*deform_group*step
    const T *data_col,  // (C1*kh*kw, step*h*w)
    const T *data_im, // (step, C1, H, W)
    const T *data_offset,  // (step, deform_group*2*kw*kh, h, w)
    const int channels, const int height, const int width, const int kernel_h, // channels: C1
    const int kernel_w, const int pad_h, const int pad_w, const int stride_h,
    const int stride_w, const int dilation_h, const int dilation_w,
    const int channel_per_deformable_group, const int batch_size,  // batch_size = step
    const int offset_channels,  // deform_group*2*kw*kh
    const int deformable_group, const int height_col,
    const int width_col, T *grad_offset) { // grad_offset: (step, deform_group*2*kw*kh, h, w)
  // index runs over 0 .. h*w*2*kh*kw*deform_group*step, i.e. over the whole offset tensor
  CUDA_1D_KERNEL_LOOP(index, n) {
    T val = 0;
    int w = index % width_col;
    int h = (index / width_col) % height_col;
    int c = (index / width_col / height_col) % offset_channels;
    int b = (index / width_col / height_col) / offset_channels;
    // compute the start and end of the output

    const int deformable_group_index = c / (2 * kernel_h * kernel_w);  // c/(2*kh*kw) gives the deformable group this offset channel belongs to
    const int col_step = kernel_h * kernel_w;
    int cnt = 0;
    const T *data_col_ptr = data_col + deformable_group_index *
                                           channel_per_deformable_group *
                                           batch_size * width_col * height_col;
    const T *data_im_ptr =
        data_im + (b * deformable_group + deformable_group_index) *
                      channel_per_deformable_group / kernel_h / kernel_w *
                      height * width;
    const T *data_offset_ptr =
        data_offset + (b * deformable_group + deformable_group_index) * 2 *
                          kernel_h * kernel_w * height_col * width_col;

    const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w;  // offset channel index within its deformable group

    for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group;
         col_c += col_step) {
      const int col_pos =
          (((col_c * batch_size + b) * height_col) + h) * width_col + w;
      const int bp_dir = offset_c % 2;

      int j = (col_pos / width_col / height_col / batch_size) % kernel_w;
      int i =
          (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h;
      int w_out = col_pos % width_col;
      int h_out = (col_pos / width_col) % height_col;
      int w_in = w_out * stride_w - pad_w;
      int h_in = h_out * stride_h - pad_h;
      const int data_offset_h_ptr =
          (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out);
      const int data_offset_w_ptr =
          (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col +
           w_out);
      const T offset_h = data_offset_ptr[data_offset_h_ptr];
      const T offset_w = data_offset_ptr[data_offset_w_ptr];
      T inv_h = h_in + i * dilation_h + offset_h;
      T inv_w = w_in + j * dilation_w + offset_w;
      if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width)
        inv_h = inv_w = -2;
      const T weight = get_coordinate_weight(inv_h, inv_w, height, width,
                                             data_im_ptr + cnt * height * width,
                                             width, bp_dir);
      val += weight * data_col_ptr[col_pos];
      cnt += 1;
    }

    grad_offset[index] = val;
  }
}
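
get_coordinate_weight() is not shown in this article. Conceptually it returns the derivative of the bilinear interpolation with respect to the sampled coordinate (the h direction when bp_dir == 0, the w direction when bp_dir == 1), so that val accumulates grad_offset by the chain rule. A minimal sketch of that derivative (illustrative only, not the verbatim mmcv implementation):

#include <cmath>

// With v1..v4 the four neighbors (top-left, top-right, bottom-left,
// bottom-right) and lh, lw the fractional parts of (h, w), the bilinear value
//   val = (1-lh)(1-lw)v1 + (1-lh)lw v2 + lh(1-lw)v3 + lh lw v4
// has the partial derivatives implemented below.
template <typename T>
T coordinate_weight_sketch(T h, T w, const T *plane, int height, int width,
                           int stride, int bp_dir /* 0: d/dh, 1: d/dw */) {
  const int h_low = static_cast<int>(std::floor(h));
  const int w_low = static_cast<int>(std::floor(w));
  const int h_high = h_low + 1, w_high = w_low + 1;
  const T lh = h - h_low, lw = w - w_low;
  auto at = [&](int y, int x) -> T {
    return (y >= 0 && y < height && x >= 0 && x < width) ? plane[y * stride + x]
                                                         : T(0);
  };
  const T v1 = at(h_low, w_low), v2 = at(h_low, w_high);
  const T v3 = at(h_high, w_low), v4 = at(h_high, w_high);
  if (bp_dir == 0)  // gradient w.r.t. the h offset
    return -(1 - lw) * v1 - lw * v2 + (1 - lw) * v3 + lw * v4;
  // gradient w.r.t. the w offset
  return -(1 - lh) * v1 + (1 - lh) * v2 - lh * v3 + lh * v4;
}

Multiplying this weight by grad_columns at col_pos and summing over every sampling point that used this offset channel (the col_c loop above) gives the final grad_offset entry that the kernel writes out.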
