当前位置: 代码迷 >> 综合 >> torch.nn.CrossEntropy()的输入输出
  详细解决方案

torch.nn.CrossEntropy()的输入输出

热度:99   发布时间:2024-02-27 09:25:04.0

问题:torch.nn中的CrossEntropyLoss( )函数的输入输出?
回答:通过阅读PyTorch的官方文档可知,CrossEntropyLoss()函数是结合了两个函数的函数类,分别是nn.LogSoftmax()和nn.NLLLoss()。可得该函数先对输入的Tensor求其LogSoftmax,得到Softmax结果的对数值,再对该结果与标签求解NLLLoss()(计算公式:loss(input, class) = -input[class]),最后求解平均值得到CrossEntropyLoss。
输入:样本个数×特征值。(该特征值必须大于标签索引)
目标:样本个数。(其值为标签索引)
输出:输入与目标的CrossEntropyLoss。
具体如官方文档所示:

 

  相关解决方案