论文名称:《CornerNet: Detecting Objects as Paired Keypoints》
论文链接: https://arxiv.org/abs/1808.01244
参考代码: https://github.com/princeton-vl/CornerNet && https://github.com/princeton-vl/CornerNet-Lite
写在前面
CornerNet是ECCV2018的论文,借鉴了人体关键点检测的思路做检测,是一种典型的anchor-free方法。它算是开创了基于点代替检测框的思路。文章比较早了,这里仅作为个人笔记,记录下自己有需要注意的地方。
目录
写在前面
整体
优点
缺点
框架
高斯核惩罚衰减
原理
代码实现
Corner Pooling
原理
代码
其他
参考
整体
优点
-
不需要用到anchor
-
定位一个检测框&anchor的中心要考虑四边,而定位一个角点只要两边,信息不够还能通过Corner Pooling获取到关于角点的先验知识
-
效率更高,角点最多就是wh个,而检测框&anchor可能有wh^2个
-
不需要多尺度,FPN!
-
提出了用高斯核衰减给角点周围的负样本赋予不同的权重的方法,后面CenterNet、TTFNet都集成了这个用法
缺点
-
后处理如CornerPooling等较繁琐
- 一些紧挨的物体的检测会比较混乱
框架
前面是hourglass的特征提取网络,预测模型则针对左上角点和右下角点,分别预测一对heatmap,一对embedding和一对offset。
高斯核惩罚衰减
原理
对不同负样本点的损失函数采取不同权重值,通过高斯核来实现
原因:对于每个顶点,只有一个ground truth,其他位置都是负样本。红色实线框是ground truth,绿色虚线是一个预测框,可以看出这个预测框的两个角点和ground truth并不重合,但是该预测框基本框住了目标,因此是有用的预测框,所以要有一定权重的损失返回,这就是为什么要对不同负样本点的损失函数采取不同权重值的原因。具体来讲是这样的:在训练过程,模型减少负样本,在每个ground-truth顶点设定半径r区域内都是正样本,这是因为落在半径r区域内的顶点依然可以生成有效的边界定位框,橘色圆圈就是根据ground truth的左上角顶点、右下角顶点和设定的半径值画出来的,半径是根据圆圈内的角点组成的框和ground truth的IOU值大于0.7而设定的,圆圈内的点的数值是以圆心往外呈二维的高斯分布exp(-(x^2+y^2)/2σ^2),σ=1/3设置的,其中,中心坐标是标注的角点定位。
原文:For each corner, there is one ground-truth positive location, and all other locations are negative. During training, instead of equally penalizing negative locations, we reduce the penalty given to negative locations within a radius of the positive location. This is because a pair of false corner detections, if they are close to their respective ground truth locations, can still produce a box that sufficiently overlaps the ground-truth box (Fig. 5). We determine the radius by the size of an object by ensuring that a pair of points within the radius would generate a bounding box with at least t IoU with the ground-truth annotation (we set t to 0:3 in all experiments). Given the radius, the amount of penalty reduction is given by an unnormalized 2D Gaussian, e-x2+y22σ2 , whose center is at the positive location and whose σ is 1/3 of the radius.
代码实现
(1)
if gaussian_bump: # 使用二维高斯给出惩罚减少量width = detection[2] - detection[0]height = detection[3] - detection[1]width = math.ceil(width * width_ratio) # math.ceil, 向上取整height = math.ceil(height * height_ratio)if gaussian_rad == -1: # 通过与gt box的overlap计算gaussian radiusradius = gaussian_radius((height, width), gaussian_iou)radius = max(0, int(radius))else: # 使用提前人为设定的gaussian radiusradius = gaussian_rad# gt heatmapdraw_gaussian(tl_heatmaps[b_ind, category], [xtl, ytl], radius) # 以[xtl, ytl]为center, radius为半径绘制二维gaussian mapdraw_gaussian(br_heatmaps[b_ind, category], [xbr, ybr], radius) # 以[xbr, ybr]为center, radius为半径绘制二维gaussian mapelse: # 不使用二维高斯给出惩罚减少量,gt corner设为1,其余点均为0tl_heatmaps[b_ind, category, ytl, xtl] = 1br_heatmaps[b_ind, category, ybr, xbr] = 1
(2)
def gaussian2D(shape, sigma=1):m, n = [(ss - 1.) / 2. for ss in shape]y, x = np.ogrid[-m:m+1,-n:n+1] # y:[[-m], [-m+1], ..., [m]].Transpose, x:[[-n, -n+1, ..., n]]h = np.exp(-(x * x + y * y) / (2 * sigma * sigma)) # 2D gaussianh[h < np.finfo(h.dtype).eps * h.max()] = 0 # np.finfo(h.dtype).eps, finfo函数获取h.dtype类型信息,eps是取非负的最小值。return hdef draw_gaussian(heatmap, center, radius, k=1): # 参数:heatmaps[b_ind, category], [x, y], gaussian radiusdiameter = 2 * radius + 1 # 直径gaussian = gaussian2D((diameter, diameter), sigma=diameter / 6) # sigma设置为diameter/6,而非论文章中所说的1/3x, y = centerheight, width = heatmap.shape[0:2]left, right = min(x, radius), min(width - x, radius + 1) # 范围限制一下,以防超出heatmap大小,越界top, bottom = min(y, radius), min(height - y, radius + 1)masked_heatmap = heatmap[y - top:y + bottom, x - left:x + right] # 一般地, [y-radius, y+radius+1], [x-radius, x+radius+1]masked_gaussian = gaussian[radius - top:radius + bottom, radius - left:radius + right] # 一般地, [0, 2*radius+1], [0, 2*radius+1]np.maximum(masked_heatmap, masked_gaussian * k, out=masked_heatmap)def gaussian_radius(det_size, min_overlap): # 计算gaussian radius。参数: (height, width), gaussian_iouheight, width = det_sizea1 = 1b1 = (height + width)c1 = width * height * (1 - min_overlap) / (1 + min_overlap)sq1 = np.sqrt(b1 ** 2 - 4 * a1 * c1)r1 = (b1 - sq1) / (2 * a1)a2 = 4b2 = 2 * (height + width)c2 = (1 - min_overlap) * width * heightsq2 = np.sqrt(b2 ** 2 - 4 * a2 * c2)r2 = (b2 - sq2) / (2 * a2)a3 = 4 * min_overlapb3 = -2 * min_overlap * (height + width)c3 = (min_overlap - 1) * width * heightsq3 = np.sqrt(b3 ** 2 - 4 * a3 * c3)r3 = (b3 + sq3) / (2 * a3)return min(r1, r2, r3)
Corner Pooling
原理
corner pooling是一种特定的pooling方法,可以让预测的结果中具有一定的角点的先验知识,从而用来改善预测的角点的位置准确性。
简单来说,就是检测框的角点不一定(经常不)包含有用的信息,如上图中的两个左上角点,需要从图像的右边和下面看过去,做类似max pooling的操作:如果有比它更大的点,就替换它的值,分别得到水平和数值两张feature map,再相加得到最后的feature map。
代码
参考【6】
std::vector<at::Tensor> pool_forward(at::Tensor input
) {// Initialize output output的形状跟input是一致的,所以先根据input构建出outputat::Tensor output = at::zeros_like(input);// Get width 拿到长度int64_t width = input.size(3);// Copy the last column,left pooling是一行,从右往左进行的,所以最后一个的input的值和output的值是一致的,下面三行代码就是实现复制的代码。at::Tensor input_temp = input.select(3, width - 1);at::Tensor output_temp = output.select(3, width - 1);output_temp.copy_(input_temp);// 接下来就是从倒数第二个开始,逐个比较,永远把最大的放到output当前的位置上。at::Tensor max_temp;for (int64_t ind = 1; ind < width; ++ind) {input_temp = input.select(3, width - ind - 1); output_temp = output.select(3, width - ind); max_temp = output.select(3, width - ind - 1);at::max_out(max_temp, input_temp, output_temp);}return {output};
}
其他
-
offset表示在取整计算时丢失的精度信息,用来修正预测的角点位置,使用的损失是smooth L1 Loss,只在GT角点处计算。
-
Embedding层和对应的损失函数是用来将左上角点和右下角点进行归类匹配,即判断 if a top-left corner and a bottom-right corner belong to the same bounding box。损失函数为L_pull&L_psh,也只在GT角点处计算,参考的论文是<Pixels to graphs by associative embedding>。
-
在hourglass的网络上,使用了resnet的残差结构作为基础进行魔改
-
后续作者还做了一版CornerNet-Lite,针对推理和训练速度又做了进一步的优化改进
参考
【1】CornerNet算法解读 - 逍遥王可爱的文章 - 知乎 https://zhuanlan.zhihu.com/p/53407590
【2】CornerNet 和 CornerNet-Lite - 张佳程的文章 - 知乎 https://zhuanlan.zhihu.com/p/73422357
【3】CornerNet-Lite源码学习(二) - 张佳程的文章 - 知乎 https://zhuanlan.zhihu.com/p/84646599
【4】说点Cornernet/Centernet代码里面GT heatmap里面如何应用高斯散射核 - Monstarrrr的文章 - 知乎
https://zhuanlan.zhihu.com/p/96856635
【5】Centernet相关---尤其有关heatmap相关解释 - Big Fish的文章 - 知乎 https://zhuanlan.zhihu.com/p/85194783
【6】CornerNet,CenterNet关键代码解读: kp,_decode,left pooling
https://blog.csdn.net/Chunfengyanyulove/article/details/94646724