关于深度学习:PyTorch实现非极大值抑制NMS

34次阅读

共计 5050 个字符,预计需要花费 13 分钟才能阅读完成。

NMS 即 non maximum suppression 即非极大克制,顾名思义就是克制不是极大值的元素,搜寻部分的极大值。在最近几年常见的物体检测算法(包含 rcnn、sppnet、fast-rcnn、faster-rcnn 等)中,最终都会从一张图片中找出很多个可能是物体的矩形框,而后为每个矩形框为做类别分类概率。本文来通过 Pytorch 实现 NMS 算法。

如果你在做计算机视觉(特地是指标检测),你必定会据说过非极大值克制(nms)。网上有很多不错的文章给出了适当的概述。简而言之,非最大克制应用一些启发式办法缩小了输入边界框的数量,例如穿插除以并集(iou)。

在 PyTorch 的文档中说:NMS 迭代地删除与另一个(得分较高)框的 IoU 大于 iou_threshold 的得分较低的框。

为了钻研其如何工作,让咱们加载一个图像并创立边界框

 from PIL import Image
 import torch
 import matplotlib.pyplot as plt
 import numpy as np
 
 # credit https://i0.wp.com/craffic.co.in/wp-content/uploads/2021/02/ai-remastered-rick-astley-never-gonna-give-you-up.jpg?w=1600&ssl=1
 img = Image.open("./samples/never-gonna-give-you-up.webp")
 img

咱们手动创立 两个框,一个人脸,一个话筒

 original_bboxes = torch.tensor([
     # head
     [565, 73, 862, 373],
     # mic
     [807, 309, 865, 434]
 ]).float()
 
 w, h = img.size
 # we need them in range [0, 1]
 original_bboxes[...,0] /= h
 original_bboxes[...,1] /= w
 original_bboxes[...,2] /= h
 original_bboxes[...,3] /= w

这些 bboxes 都是在 [0,1] 范畴内的,尽管这不是必须的,但当有多个类时,这是十分有用的(咱们稍后将看到为什么)。

 from torchvision.utils import draw_bounding_boxes
 from torchvision.transforms.functional import to_tensor
 from typing import List
 
 def plot_bboxes(img : Image.Image, bboxes: torch.Tensor, *args, **kwargs) -> plt.Figure:
     w, h = img.size
     # from [0, 1] to image size
     bboxes = bboxes.clone()
     bboxes[...,0] *= h
     bboxes[...,1] *= w
     bboxes[...,2] *= h
     bboxes[...,3] *= w
     fig = plt.figure()
     img_with_bboxes = draw_bounding_boxes((to_tensor(img) * 255).to(torch.uint8), bboxes, *args, **kwargs, width=4)
     return plt.imshow(img_with_bboxes.permute(1,2,0).numpy())
 
 plot_bboxes(img, original_bboxes, labels=["head", "mic"])

为了阐明,咱们增加一些重叠的框

 max_bboxes = 3
 scaling = torch.tensor([1, .96, .97, 1.02])
 shifting = torch.tensor([0, 0.001, 0.002, -0.002])
 
 # broadcasting magic (2, 1, 4) * (1, 3, 1)
 bboxes = (original_bboxes[:,None,:] * scaling[..., None] + shifting[..., None]).view(-1, 4)
 
 plot_bboxes(img, bboxes, colors=[*["yellow"] * 4, *["blue"] * 4], labels=[*["head"] * 4, *["mic"] * 4])

当初能够看到,有 6 个 bboxes,这里咱们还须要定义一个分数,这通常由模型输入。

 scores = torch.tensor([
     0.98, 0.85, 0.5, 0.2, # for head
     1, 0.92, 0.3, 0.1 # for mic
 ])

咱们标签的分类,0 代表人脸,1 代表麦克风

 labels = torch.tensor([0,0,0,0,1,1,1,1])

最初,让咱们排列一下这些数据

 perm = torch.randperm(scores.shape[0])
 bboxes = bboxes[perm]
 scores = scores[perm]
 labels = labels[perm]

让咱们看看后果

 plot_bboxes(img, bboxes, 
             colors=["yellow" if el.item() == 0 else "blue" for el in labels], 
             labels=["head" if el.item()  == 0 else "mic" for el in labels]
            )

好了,这样咱们模仿了模型的输入了,上面进入正题。

NMS 是通过迭代删除低分数重叠的边界框来工作的。步骤如下。

bboxes are sorted by score in decreasing order
init a vector keep with ones
for i in len(bboxes):
    # was suppressed
    if keep[i] == 0:
        continue
    # compare with all the others
    for j in len(bbox):
        if keep[j]:
            if (iou(bboxes[i], bboxes[j]) > iou_threshold):
                keep[j] = 0

return keep

咱们的 Pytorch 实现,采纳三个参数(这实际上是从 pytorch 的文档中复制和粘贴的):

  • box (Tensor[N, 4])) – 用于执行 NMS 的框。它们应该是 (x1, y1, x2, y2) 格局,0 <= x1 < x2 和 0 <= y1 < y2。
  • score (Tensor[N]) – 每个 box 的得分
  • iou_threshold (float) – 抛弃所有 IoU > iou_threshold 的框
  • 返回值是非克制边界框的索引
from torchvision.ops.boxes import box_iou

def nms(bboxes: torch.Tensor, scores: torch.Tensor, iou_threshold: float) -> torch.Tensor:
    order = torch.argsort(-scores)
    indices = torch.arange(bboxes.shape[0])
    keep = torch.ones_like(indices, dtype=torch.bool)
    for i in indices:
        if keep[i]:
            bbox = bboxes[order[i]]
            iou = box_iou(bbox[None,...],(bboxes[order[i + 1:]]) * keep[i + 1:][...,None])
            overlapped = torch.nonzero(iou > iou_threshold)
            keep[overlapped + i + 1] = 0
    return order[keep]

让咱们具体阐明下这个参数:

order = scores.argsort()

依据分数失去排序的指标

indices = torch.arange(bboxes.shape[0])

创立用于迭代 bboxes 的索引 indices

keep = torch.ones_like(indices, dtype=torch.bool)

keep 是用于判断一个 bbox 是否应该保留的向量,如果 Keep [i] == 1,则 bboxes[order[i]]不被克制

for i in indices:
    ...

for 循环遍历所有的 box, 如果以后 box 未被克制,则 keep[i] = 1

bbox = bboxes[order[i]]]

来通过已排序的地位获取 bbox

iou = box_iou(bbox[None,...], (bboxes[order[i + 1:]]) * keep[i + 1:][...,None])

计算以后 bbox 和所有其余候选 bbox 之间的 iou。这将把所有克制框设置为零(因为 keep 将等于 0)

(bboxes ...)[order[i + 1:]]

在排序的程序中与前面所有的框进行比拟,因为须要跳过以后的框,所以这里是 i + 1,

overlapped = torch.nonzero(iou > iou_threshold)
keep[overlapped + i + 1] = 0

计算和抉择 iou 大于 iou_threshold 的索引。

咱们之前对 bboxes 进行了切片,(bboxes…)[i + 1:]),所以咱们须要增加这些索引的偏移量,这就是前面 + i + 1 的起因。

最初返回 order[keep],这样映射回原始的 box 索引(未排序),这样一个简略的函数就执行实现了。

让咱们看看后果

nms_indices = nms(bboxes, scores, .45)
plot_bboxes(img, 
            bboxes[nms_indices],
            colors=["yellow" if el.item() == 0 else "blue" for el in labels[nms_indices]], 
            labels=["head" if el.item()  == 0 else "mic" for el in labels[nms_indices]]
           )

因为有多个类,所以须要让 nms 在同一个类中计算 iou。还记得下面咱们提到的在 [0,1] 之间吗? 能够给它们增加标签,把不同类的框辨别开。

nms_indices = nms(bboxes + labels[..., None], scores, .45)
plot_bboxes(img, 
            bboxes[nms_indices],
            colors=["yellow" if el.item() == 0 else "blue" for el in labels[nms_indices]], 
            labels=["head" if el.item()  == 0 else "mic" for el in labels[nms_indices]]
           )

如果咱们将阈值更改为 0.1,就失去了下图

让咱们比照下 pytorch 官网的实现:

from torchvision.ops.boxes import nms as torch_nms
nms_indices = torch_nms(bboxes + labels[..., None], scores, .45)
plot_bboxes(img, 
            bboxes[nms_indices],
            colors=["yellow" if el.item() == 0 else "blue" for el in labels[nms_indices]], 
            labels=["head" if el.item()  == 0 else "mic" for el in labels[nms_indices]]
           )

后果是一样的。然咱们看看工夫:

%%timeit
nms(bboxes + labels[..., None], scores, .45)
#534 µs ± 22.1 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

%%timeit
torch_nms(bboxes + labels[..., None], scores, .45)
#54.4 µs ± 3.29 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

咱们的实现慢了大概 10 倍,哈,这个后果很失常,因为咱们咱们没有应用自定义的 cpp 内核! 然而这并不代表咱们的实现没有用,因为手写代码咱们齐全理解了 NMS 的工作原理,这是本文的真正意义,总之在这篇文章中咱们看到了如何在 PyTorch 中实现非最大克制,这对你理解指标检测的相干常识是十分有帮忙的。

https://avoid.overfit.cn/post/1ffeb08f8ea4494cb992b0ad05db174b

作者:Francesco Zuppichini

正文完
 0