问题:torch.nn中的CrossEntropyLoss( )函数的输入输出?
回答:通过阅读PyTorch的官方文档可知,CrossEntropyLoss()函数是结合了两个函数的函数类,分别是nn.LogSoftmax()和nn.NLLLoss()。可得该函数先对输入的Tensor求其LogSoftmax,得到Softmax结果的对数值,再对该结果与标签求解NLLLoss()(
计算公式:loss(input, class) = -input[class]),最后求解平均值得到CrossEntropyLoss。
输入:样本个数×特征值。(该特征值必须大于标签索引)
目标:样本个数。(其值为标签索引)
输出:输入与目标的CrossEntropyLoss。
具体如官方文档所示: