balance binary cross entropy损失函数在分割任务中很有用,因为分割任务会遇到正负样本不均的问题,甚至在边缘的分割任务重,样本不均衡达到了很高的比例。我们先来了解原理,再了解具体如何编程。
原理
比如一个预测结果,记作P∈RH×WP \in R^{H \times W}P∈RH×W,对应的label是R,尺寸一样。
R中的1和0,即正负样本比例很不协调。我们想给这两个类别一个权重系数,乘在每一个像素计算的loss上。
这个系数的算法是:
apos=numneg/(H×W)a_{pos} = num_{neg} / (H \times W)apos?=numneg?/(H×W)
aneg=numpos/(H×W)a_{neg} = num_{pos} / (H \times W)aneg?=numpos?/(H×W)
正样本的系数是通过负例的数目占总数目的比例得到,负样本的系数是正例的数目占总数目的比例得到。
至于为啥计算正样本的系数需要用负样本的比例,那是因为正样本数目少,就给一个大的比例,增大一下正样本的梯度,不至于负样本的梯度占统治地位(dominating),避免网络倾向于把样本判断为负样本。
代码
def bce2d(pred, gt, reduction='mean'):pos = torch.eq(gt, 1).float()neg= torch.eq(gt, 0).float()num_pos = torch.sum(pos)num_neg = torch.sum(neg)num_total = num_pos + num_negalpha_pos = num_neg / num_totalalpha_neg = num_pos / num_totalweights = alpha_pos * pos + alpha_neg * negreturn F.binary_cross_entropy_with_logits(pred, target, weights, reduction = reduction)
Note
有些情况下,正样本实在太少,负样本是在太多。这样情况下,负样本的权重系数就会接近0,使得训练出来的网络仍然是biased。我们可以通过在负样本系数上乘以一个大于1的值。
aneg=1.1×numpos/(H×W)a_{neg} = 1.1 \times num_{pos} / (H \times W)aneg?=1.1×numpos?/(H×W)