当前位置: 代码迷 >> 综合 >> mx.metric.CrossEntropy()里面的坑
  详细解决方案

mx.metric.CrossEntropy()里面的坑

热度:35   发布时间:2023-12-15 16:23:49.0

默认加了一个非常小的常数,为了防止求对数的时候,真数部分为0,一般输入时不受影响。但当输入非常小的时候,输出变化非常大。

结论是:在这种情况下,不要使用内置API进行运算,用手写的。

以下两段代码在一般输入(概率值不是极其小)时等效。

from mxnet import nd
loss = nd.mean(-nd.pick(prob, label).log())

ce = mx.metric.CrossEntropy()
ce.update(global_label, prob_softmax)
loss = ce.get()[1]

在输入极其小时,不等效。

原因:(查看mx.metric.CrossEntropy()源码)

class CrossEntropy(EvalMetric):"""Computes Cross Entropy loss.The cross entropy over a batch of sample size :math:`N` is given by.. math::-\\sum_{n=1}^{N}\\sum_{k=1}^{K}t_{nk}\\log (y_{nk}),where :math:`t_{nk}=1` if and only if sample :math:`n` belongs to class :math:`k`.:math:`y_{nk}` denotes the probability of sample :math:`n` belonging toclass :math:`k`.Parameters----------eps : floatCross Entropy loss is undefined for predicted value is 0 or 1,so predicted values are added with the small constant.# 防止真数部分为0或1name : strName of this metric instance for display.output_names : list of str, or NoneName of predictions that should be used when updating with update_dict.By default include all predictions.label_names : list of str, or NoneName of labels that should be used when updating with update_dict.By default include all labels.Examples-------->>> predicts = [mx.nd.array([[0.3, 0.7], [0, 1.], [0.4, 0.6]])]>>> labels = [mx.nd.array([0, 1, 1])]>>> ce = mx.metric.CrossEntropy()>>> ce.update(labels, predicts)>>> print ce.get()('cross-entropy', 0.57159948348999023)"""def __init__(self, eps=1e-12, name='cross-entropy',output_names=None, label_names=None):super(CrossEntropy, self).__init__(name, eps=eps,output_names=output_names, label_names=label_names,has_global_stats=True)self.eps = epsdef update(self, labels, preds):"""Updates the internal evaluation result.Parameters----------labels : list of `NDArray`The labels of the data.preds : list of `NDArray`Predicted values."""labels, preds = check_label_shapes(labels, preds, True)for label, pred in zip(labels, preds):label = label.asnumpy()pred = pred.asnumpy()label = label.ravel()assert label.shape[0] == pred.shape[0]prob = pred[numpy.arange(label.shape[0]), numpy.int64(label)]cross_entropy = (-numpy.log(prob + self.eps)).sum() # 这里的 self.eps = 1e-12self.sum_metric += cross_entropyself.global_sum_metric += cross_entropyself.num_inst += label.shape[0]self.global_num_inst += label.shape[0]

简单理解:

>>> np.log(1.11111)
0.10535951565732633
>>> np.log(1.11111+1e-12)
0.10535951565822642
>>> np.log(1e-43)
-99.01115899874397
>>> np.log(1e-43+1e-12)
-27.631021115928547

可以看到在输入特别小的情况下,输出变化很大

  相关解决方案