1 理论
OHEM(Online Hard Example Mining,在线难例挖掘)的做法是:每次前向计算后,只选取 loss 最大的前若干个样本(即难例)来计算 loss 并反向传播,使训练集中在这些难例上。
2 实现
def rpn_class_loss_graph(config, rpn_match, rpn_class_logits):
    """RPN anchor classifier loss with Online Hard Example Mining (OHEM).

    Only the anchors with the largest cross-entropy loss contribute to the
    returned value; at most ``config.RPN_TRAIN_ANCHORS_PER_IMAGE`` anchors
    are kept.

    Args:
        config: object exposing ``RPN_TRAIN_ANCHORS_PER_IMAGE`` (the number
            of hardest anchors to keep).
        rpn_match: [batch, anchors, 1]. Anchor match type. 1=positive,
            -1=negative, 0=neutral anchor.
        rpn_class_logits: [batch, anchors, 2]. RPN classifier logits for
            FG/BG.

    Returns:
        Scalar tensor: mean cross-entropy over the selected hard anchors, or
        0.0 when there is no positive/negative anchor at all.
    """
    # Squeeze last dim to simplify.
    rpn_match = tf.squeeze(rpn_match, -1)
    # Get anchor classes. Convert the -1/+1 match to 0/1 values.
    anchor_class = K.cast(K.equal(rpn_match, 1), tf.int32)
    # Positive and negative anchors contribute to the loss,
    # but neutral anchors (match value = 0) don't.
    indices = tf.where(K.not_equal(rpn_match, 0))
    # Pick rows that contribute to the loss and filter out the rest.
    # NOTE(review): this flattens the anchors of the whole batch into one
    # list, so the top-k below is taken over the batch rather than per
    # image -- confirm that is intended given the per-image config name.
    rpn_class_logits = tf.gather_nd(rpn_class_logits, indices)
    anchor_class = tf.gather_nd(anchor_class, indices)
    # Per-anchor cross entropy loss.
    ce_loss = K.sparse_categorical_crossentropy(target=anchor_class,
                                                output=rpn_class_logits,
                                                from_logits=True)
    # Never ask top_k for more elements than exist: tf.nn.top_k fails when
    # k exceeds the input length.
    n_selected = tf.minimum(
        tf.cast(config.RPN_TRAIN_ANCHORS_PER_IMAGE, tf.int32),
        tf.size(ce_loss))
    # Average exactly the k hardest anchors (ties beyond k are not included).
    hard_losses, _ = tf.nn.top_k(ce_loss, k=n_selected, sorted=False)
    # Guard the empty case, where the mean would be NaN.
    loss = K.switch(n_selected > 0, K.mean(hard_losses), tf.constant(0.0))
    return loss
3 其他实例(PyTorch 实现)
class topk_crossEntrophy(nn.Module):
    """Cross-entropy loss averaged over the hardest fraction of a batch (OHEM).

    The per-sample negative log-likelihood is computed for every sample, and
    only the ``top_k`` fraction with the largest loss is averaged.

    Args:
        top_k: fraction of the batch to keep. Values >= 1 keep every sample,
            i.e. plain mean cross-entropy.
    """

    def __init__(self, top_k=0.7):
        super(topk_crossEntrophy, self).__init__()
        # reduction="none" yields one loss value per sample so they can be
        # ranked by difficulty.
        self.loss = nn.NLLLoss(reduction="none")
        self.top_k = top_k
        # Normalise over the class dimension of an (N, C) input.
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, input, target):
        """Return the mean loss over the hardest samples.

        Args:
            input: raw class scores; assumed to be of shape (N, C).
            target: class indices; assumed to be of shape (N,).

        Returns:
            Scalar tensor holding the mean loss of the selected samples
            (0 for an empty batch).
        """
        log_probs = self.softmax(input)
        per_sample = self.loss(log_probs, target)
        n_samples = per_sample.size(0)
        if n_samples == 0:
            # Sum of an empty tensor is 0 and stays attached to the graph.
            return per_sample.sum()
        if self.top_k >= 1:
            return per_sample.mean()
        # Keep at least one sample so the mean is never taken over nothing.
        n_keep = min(n_samples, max(1, int(self.top_k * n_samples)))
        hardest, _ = torch.topk(per_sample, n_keep)
        return hardest.mean()
参考:
- Github ;
- 博客详解OHEM;
- pytorch 社区;
- tensorflow 实现