当前位置: 代码迷 >> 综合 >> TensorFlow 中计算给定 logits 的 sigmoid 交叉熵 tf.sigmoid_cross_entropy_with_logits 的基本用法及实例代码
  详细解决方案

TensorFlow 中计算给定 logits 的 sigmoid 交叉熵 tf.sigmoid_cross_entropy_with_logits 的基本用法及实例代码

热度:74   发布时间:2024-01-14 07:01:37.0

一、环境

TensorFlow API r1.12

CUDA 9.2 V9.2.148

cudnn64_7.dll

Python 3.6.3

Windows 10

 

二、官方说明

计算输入张量 logits 的 sigmoid 交叉熵

https://tensorflow.google.cn/api_docs/python/tf/nn/sigmoid_cross_entropy_with_logits

tf.nn.sigmoid_cross_entropy_with_logits(_sentinel=None,labels=None,logits=None,name=None
)

计算离散型分类任务中的概率误差,其中每个类别都是独立的但不是互斥的。

可以用它来计算多标签分类任务,即一幅图片可以同时具有多个类别标签,如大象和狗

其 logistic loss 的计算方式如下,其中 x = logits, z = labels

z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
= z * -log(1 / (1 + exp(-x))) + (1 - z) * -log(exp(-x) / (1 + exp(-x)))
= z * log(1 + exp(-x)) + (1 - z) * (-log(exp(-x)) + log(1 + exp(-x)))
= z * log(1 + exp(-x)) + (1 - z) * (x + log(1 + exp(-x))
= (1 - z) * x + log(1 + exp(-x))
= x - x * z + log(1 + exp(-x))

在 x < 0 的情况下,为了避免计算 exp(-x) 移除,将会按照如下方式计算 logistic 损失

x - x * z + log(1 + exp(-x))
= log(exp(x)) - x * z + log(1 + exp(-x))
= - x * z + log(1 + exp(x))

因此,为了确保计算稳定性并避免移除,最终采用下面的等式来计算 logistic 损失

max(x, 0) - x * z + log(1 + exp(-abs(x)))

参数:

_sentinel:用于保护位置参数,内部的,不使用 

labels:和 logits 具有相同类型和形状的张量

logits:类型为 float 32 或 float64 的张量

name:可选参数,操作的名称

 

返回:

具有分量式的 logistic losses 且形状与 logistic 相同的张量

 

三、实例

>>> import tensorflow as tf>>> logits = tf.constant(value=[[[0.1,0.2,0.3],[0.4,0.5,0.6],[0.7,0.8,0.9]],[[-0.1,-0.2,-0.3],[-0.4,-0.5,-0.6],[-0.7,-0.8,-0.9]]], dtype=tf.float32)
>>> logits
<tf.Tensor 'Const:0' shape=(2, 3, 3) dtype=float32>>>> labels = tf.constant(value=[[[1.0,1.0,0.0],[0.0,0.0,1.0],[1.0,1.0,1.0]],[[0.0,0.0,0.0],[0.0,0.0,0.0],[1.0,0.0,1.0]]],dtype=tf.float32)
>>> labels
<tf.Tensor 'Const_2:0' shape=(2, 3, 3) dtype=float32>>>> results = tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits)
>>> results
<tf.Tensor 'logistic_loss:0' shape=(2, 3, 3) dtype=float32>>>> sess = tf.InteractiveSession()>>> print(sess.run(results))
[[[0.6443967  0.59813887 0.8543552 ][0.91301525 0.974077   0.43748793][0.40318602 0.3711007  0.34115386]][[0.6443967  0.59813887 0.5543552 ][0.5130153  0.474077   0.43748793][1.103186   0.3711007  1.2411538 ]]]>>> sess.close()

 

  相关解决方案