当前位置: 代码迷 >> 综合 >> Torchvision.ops.batched_nms() 和 nms()区别
  详细解决方案

Torchvision.ops.batched_nms() 和 nms()区别

热度:67   发布时间:2024-01-05 06:23:25.0

区别:

batched_nms():

        根据每个类别进行过滤,只对同一种类别进行计算IOU和阈值过滤。

nms():

        不区分类别对所有bbox进行过滤。如果有不同类别的bbox重叠的话会导致被过滤掉并不会分开计算。

Torchvision.ops.nms():

参数:

boxes: Tensor, 预测框
scores: Tensor, 预测置信度
iou_threshold: float, IOU阈值

Torchvision.ops.batched_nms():

参数:

boxes: Tensor, 预测框
scores: Tensor, 预测置信度
idxs: Tensor, 预测框类别
iou_threshold: float, IOU阈值

代码测试:

import torchvision.ops as ops
import torchb = torch.Tensor([[2,2,4,4], [1,1,5,5], [3,3,3.5,3.9]]) # bbox
c = torch.Tensor([0,1,0]) # classes
s = torch.Tensor([0.8,0.8,0.8]) # scoresops.batched_nms(b, s, c, 0.001)
#运行结果 tensor([1, 2])
#[1,1,5,5], [3,3,3.5,3.9] bbox实际上是有包含关系的,但是类别不一样ops.nms(b, s, 0.001)
# 运行结果 tensor([0])
# 可以看到 [1,1,5,5] 类别为1 但是被过滤掉了,只留下0号类别的[2,2,4,4]