文章目录
- 1.torch softmax计算出0错误:
- 2.解决办法:
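1.torch softmax计算出0错误:
说明：这并不是 softmax 实现上的错误，而是 float32 的下溢。数值稳定的 softmax 计算的是 $\mathrm{softmax}(x)_i = \frac{e^{x_i - \max_j x_j}}{\sum_k e^{x_k - \max_j x_j}}$。本例中最大值为 83.0129，分母约等于 1。当 $x_i - \max_j x_j$ 小于约 $-103.97$ 时，$e^{x_i - \max_j x_j}$ 小于 float32 最小正次正规数（约 $1.4\times10^{-45}$）的一半，会被舍入为 0，因此输出中出现 0.0000e+00。例如 $-20.7425 - 83.0129 \approx -103.76$ 仍得到 1.4013e-45，而 $-21.4416 - 83.0129 \approx -104.45$ 就变成了 0。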
import torch

# Softmax over dim 0, i.e. over all elements of the 1-D tensor below.
entropy_softmax = torch.nn.Softmax(dim=0)
# 150 logits, created with torch's default dtype (float32 unless
# torch.set_default_dtype was changed). The largest value is 83.0129 at index 15.
t = torch.tensor([ 1.4588, 6.8605, -12.9596, 0.9334, 20.7457, -6.9825, 1.0430,
35.1390, -5.0085, 11.4986, 14.0365, -24.9701, 9.1256, 15.3329,
-5.0959, 83.0129, -13.2779, -8.6666, -20.7425, 37.3073, -6.4893,
3.1912, -10.2705, 14.7957, -27.7156, -7.7740, 9.0226, 13.1010,
8.0412, 10.2651, -2.7498, -13.8532, -20.4291, -6.6348, 5.6659,
7.9520, -3.7167, 23.1547, -10.2206, -4.7268, 8.1727, -5.3533,
-6.0071, 9.7224, -2.1432, -13.7613, 10.4201, 22.4950, -28.7874,
1.0044, -39.6582, -2.7599, -21.4416, 10.8764, 22.4835, 9.5556,
1.4402, -9.2514, 18.1716, -14.8400, 12.9393, 0.4393, 9.4990,
-13.0462, 15.9835, 18.0629, -0.2483, -0.7381, -13.4142, 17.9653,
-11.5424, -6.1927, -6.4485, 5.7864, -0.6982, -5.3879, 8.6595,
42.4957, -24.4659, 11.0456, 4.8009, 7.1376, 1.2844, 0.9651,
5.3984, -9.0878, -22.7269, -8.1914, -11.8171, -3.0764, 1.5658,
-1.3196, -14.6754, -3.2329, -0.7274, 12.0519, -11.1765, -16.8589,
-9.1866, -7.3975, 1.6408, 7.1755, 12.8522, 4.8147, 21.5046,
-8.0040, -0.6566, 14.9312, -24.6040, 20.5096, 8.7695, 8.5419,
17.2766, -4.2321, 6.5878, -21.8119, 5.4491, -8.9113, -0.1210,
12.1316, 2.4011, -0.5857, 3.5494, -8.4606, 8.9213, -16.8446,
-1.8504, 3.4145, -10.5080, -14.1006, -6.5070, -10.7599, 8.6234,
-0.4581, -11.4358, -7.7651, -14.2305, 5.5419, 7.3315, -16.9072,
17.5465, -15.2809, -15.2424, 0.8810, 1.3355, -4.9660, 9.5867,
15.8635, 23.4751, -14.8507])
# In float32, exp(x_i - max(t)) rounds to exactly 0 once x_i - max(t) drops
# below about -103.97 (the smallest float32 subnormal is ~1.4e-45). Several
# entries of `a` therefore print as 0.0000e+00. `a` is still a valid
# distribution. The zeros only cause trouble downstream, e.g. log(a) -> -inf.
# When log-probabilities are the goal, use
# torch.nn.functional.log_softmax(t, dim=0), which stays finite here.
a = entropy_softmax(t)
print(a)
tensor([3.8151e-36, 8.4612e-34, 2.0879e-42, 2.2559e-36, 9.0719e-28, 8.2318e-40,
2.5172e-36, 1.6167e-21, 5.9264e-39, 8.7445e-32, 1.1064e-30, 0.0000e+00,
8.1499e-33, 4.0453e-30, 5.4304e-39, 1.0000e+00, 1.5190e-42, 1.5279e-40,
1.4013e-45, 1.4135e-20, 1.3480e-39, 2.1571e-35, 3.0728e-41, 2.3640e-30,
0.0000e+00, 3.7303e-40, 7.3522e-33, 4.3416e-31, 2.7555e-33, 2.5470e-32,
5.6719e-38, 8.5479e-43, 1.4013e-45, 1.1655e-39, 2.5623e-34, 2.5204e-33,
2.1568e-38, 1.0091e-26, 3.2300e-41, 7.8547e-39, 3.1428e-33, 4.1980e-39,
2.1832e-39, 1.4803e-32, 1.0403e-37, 9.3607e-43, 2.9740e-32, 5.2169e-27,
0.0000e+00, 2.4219e-36, 0.0000e+00, 5.6149e-38, 0.0000e+00, 4.6937e-32,
5.1572e-27, 1.2529e-32, 3.7448e-36, 8.5139e-41, 6.9148e-29, 3.1809e-43,
3.6934e-31, 1.3764e-36, 1.1839e-32, 1.9142e-42, 7.7535e-30, 6.2026e-29,
6.9202e-37, 4.2404e-37, 1.3256e-42, 5.6258e-29, 8.6124e-42, 1.8134e-39,
1.4041e-39, 2.8904e-34, 4.4130e-37, 4.0553e-39, 5.1136e-33, 2.5328e-18,
0.0000e+00, 5.5590e-32, 1.0788e-34, 1.1163e-33, 3.2045e-36, 2.3286e-36,
1.9609e-34, 1.0027e-40, 0.0000e+00, 2.4574e-40, 6.5441e-42, 4.0916e-38,
4.2459e-36, 2.3706e-37, 3.7555e-43, 3.4988e-38, 4.2860e-37, 1.5207e-31,
1.2418e-41, 4.2039e-44, 9.0838e-41, 5.4358e-40, 4.5766e-36, 1.1594e-33,
3.3853e-31, 1.0938e-34, 1.9377e-27, 2.9639e-40, 4.6004e-37, 2.7070e-30,
0.0000e+00, 7.1641e-28, 5.7082e-33, 4.5463e-33, 2.8255e-29, 1.2882e-38,
6.4417e-34, 0.0000e+00, 2.0629e-34, 1.1963e-40, 7.8597e-37, 1.6468e-31,
9.7890e-36, 4.9384e-37, 3.0863e-35, 1.8774e-40, 6.6440e-33, 4.3440e-44,
1.3942e-37, 2.6968e-35, 2.4231e-41, 6.6702e-43, 1.3243e-39, 1.8836e-41,
4.9323e-33, 5.6105e-37, 9.5821e-42, 3.7637e-40, 5.8574e-43, 2.2635e-34,
1.3551e-33, 4.0638e-44, 3.7009e-29, 2.0459e-43, 2.1300e-43, 2.1408e-36,
3.3725e-36, 6.1837e-39, 1.2924e-32, 6.8768e-30, 1.3901e-26, 3.1529e-43])
# The same softmax through the functional API, with the axis passed by keyword.
print(torch.nn.functional.softmax(input=t, dim=0))
tensor([3.8151e-36, 8.4612e-34, 2.0879e-42, 2.2559e-36, 9.0719e-28, 8.2318e-40,
2.5172e-36, 1.6167e-21, 5.9264e-39, 8.7445e-32, 1.1064e-30, 0.0000e+00,
8.1499e-33, 4.0453e-30, 5.4304e-39, 1.0000e+00, 1.5190e-42, 1.5279e-40,
1.4013e-45, 1.4135e-20, 1.3480e-39, 2.1571e-35, 3.0728e-41, 2.3640e-30,
0.0000e+00, 3.7303e-40, 7.3522e-33, 4.3416e-31, 2.7555e-33, 2.5470e-32,
5.6719e-38, 8.5479e-43, 1.4013e-45, 1.1655e-39, 2.5623e-34, 2.5204e-33,
2.1568e-38, 1.0091e-26, 3.2300e-41, 7.8547e-39, 3.1428e-33, 4.1980e-39,
2.1832e-39, 1.4803e-32, 1.0403e-37, 9.3607e-43, 2.9740e-32, 5.2169e-27,
0.0000e+00, 2.4219e-36, 0.0000e+00, 5.6149e-38, 0.0000e+00, 4.6937e-32,
5.1572e-27, 1.2529e-32, 3.7448e-36, 8.5139e-41, 6.9148e-29, 3.1809e-43,
3.6934e-31, 1.3764e-36, 1.1839e-32, 1.9142e-42, 7.7535e-30, 6.2026e-29,
6.9202e-37, 4.2404e-37, 1.3256e-42, 5.6258e-29, 8.6124e-42, 1.8134e-39,
1.4041e-39, 2.8904e-34, 4.4130e-37, 4.0553e-39, 5.1136e-33, 2.5328e-18,
0.0000e+00, 5.5590e-32, 1.0788e-34, 1.1163e-33, 3.2045e-36, 2.3286e-36,
1.9609e-34, 1.0027e-40, 0.0000e+00, 2.4574e-40, 6.5441e-42, 4.0916e-38,
4.2459e-36, 2.3706e-37, 3.7555e-43, 3.4988e-38, 4.2860e-37, 1.5207e-31,
1.2418e-41, 4.2039e-44, 9.0838e-41, 5.4358e-40, 4.5766e-36, 1.1594e-33,
3.3853e-31, 1.0938e-34, 1.9377e-27, 2.9639e-40, 4.6004e-37, 2.7070e-30,
0.0000e+00, 7.1641e-28, 5.7082e-33, 4.5463e-33, 2.8255e-29, 1.2882e-38,
6.4417e-34, 0.0000e+00, 2.0629e-34, 1.1963e-40, 7.8597e-37, 1.6468e-31,
9.7890e-36, 4.9384e-37, 3.0863e-35, 1.8774e-40, 6.6440e-33, 4.3440e-44,
1.3942e-37, 2.6968e-35, 2.4231e-41, 6.6702e-43, 1.3243e-39, 1.8836e-41,
4.9323e-33, 5.6105e-37, 9.5821e-42, 3.7637e-40, 5.8574e-43, 2.2635e-34,
1.3551e-33, 4.0638e-44, 3.7009e-29, 2.0459e-43, 2.1300e-43, 2.1408e-36,
3.3725e-36, 6.1837e-39, 1.2924e-32, 6.8768e-30, 1.3901e-26, 3.1529e-43])
# `0.0000e+00 == 0` only compares two Python literals, so it is always True
# and says nothing about the tensor. Check the softmax output itself instead:
# this prints True only if some entries of `a` are exactly zero.
print(bool((a == 0).any()))
True
2.解决办法:
# Patching zeros afterwards, e.g. `if val == 0.0000e+00: val = 1e-50`, does not
# work for float32 tensors. 1e-50 is below the smallest float32 subnormal
# (~1.4e-45), so it rounds straight back to 0 when stored.
# If log-probabilities are the goal, use torch.nn.functional.log_softmax instead.
# Otherwise, clamp the probabilities to a positive value the dtype can represent:
def clamp_probs(probs, eps=None):
    """Return a copy of *probs* with every entry raised to at least *eps*.

    Args:
        probs: floating-point tensor, e.g. a softmax output.
        eps: lower bound. Defaults to ``torch.finfo(probs.dtype).tiny``, the
            smallest positive normal number of the tensor's dtype, so the
            bound itself cannot underflow to 0.

    Returns:
        A new tensor with the same shape and dtype. Entries below *eps*
        (including exact zeros) are replaced by *eps*, so ``torch.log`` of the
        result stays finite. Each entry grows by at most *eps*, so the sum may
        exceed 1 by at most ``probs.numel() * eps``. The input is not modified.
    """
    if eps is None:
        eps = torch.finfo(probs.dtype).tiny
    return torch.clamp(probs, min=eps)