Code:
target.sum(axis=0)
Error:
Traceback (most recent call last):
  File "/home/user1/test.py", line 101, in <module>
    t = target.sum(axis=0).reshape((1, 40))
TypeError: sum() received an invalid combination of arguments - got (axis=int, ), but expected one of:
 * ()
      didn't match because some of the keywords were incorrect: axis
 * (torch.dtype dtype)
      didn't match because some of the keywords were incorrect: axis
 * (tuple of ints dim, torch.dtype dtype)
 * (tuple of ints dim, bool keepdim, torch.dtype dtype)
 * (tuple of ints dim, bool keepdim)
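Cause: `target` is a torch.Tensor, not a NumPy array. In the PyTorch version that raised this error, Tensor.sum() names the reduction axis `dim`. None of the signatures listed in the error accepts an `axis` keyword, so the NumPy-style call `sum(axis=0)` is rejected. (Newer PyTorch releases also accept `axis` as a NumPy-compatible alias for `dim`.)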
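A minimal sketch of staying in PyTorch, assuming `target` is a 2-D float tensor with 40 columns. The 100×40 size and the `arange` contents are made-up stand-ins. `dim=0` (or the positional form `target.sum(0)`) works in every PyTorch release. The NumPy result is computed only to confirm the two routes agree.

```python
import numpy as np
import torch

# Hypothetical stand-in for `target`: 100 rows, 40 columns.
target = torch.arange(4000, dtype=torch.float32).reshape(100, 40)

# Use PyTorch's `dim` keyword instead of NumPy's `axis`.
col_sums = target.sum(dim=0)        # shape: (40,)
t = col_sums.reshape((1, 40))       # shape: (1, 40), as in the failing line

# Cross-check against the NumPy route.
np_t = target.numpy().sum(axis=0).reshape((1, 40))
print(t.shape, np.allclose(t.numpy(), np_t))
```

Keeping the reduction in PyTorch also keeps `t` on the same device and in the autograd graph, which the NumPy conversion does not.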
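Fix: pass `dim` instead of `axis`, i.e. `t = target.sum(dim=0).reshape((1, 40))`. Alternatively, convert the tensor to a NumPy array first and keep the NumPy-style call:
target = target.numpy()  # convert torch.Tensor to numpy.ndarray
t = target.sum(axis=0).reshape((1, 40))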
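The NumPy route has a caveat. `.numpy()` raises a RuntimeError when the tensor requires grad, and it also fails for tensors on a CUDA device. The sketch below assumes a hypothetical CPU tensor created with `requires_grad=True` to show both the failure and the usual workaround, `detach().cpu().numpy()`:

```python
import torch

# Hypothetical tensor that is part of an autograd graph.
target = torch.ones(100, 40, requires_grad=True)

try:
    target.numpy()  # raises RuntimeError: tensor requires grad
    plain_numpy_failed = False
except RuntimeError:
    plain_numpy_failed = True

# detach() drops the autograd graph; cpu() copies a GPU tensor to host
# memory (and returns the tensor unchanged if it is already on the CPU).
arr = target.detach().cpu().numpy()
t = arr.sum(axis=0).reshape((1, 40))
print(plain_numpy_failed, t.shape, t[0, 0])
```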
Credit: Shai @ https://stackoverflow.com/questions/54999926/pytorch-typeerror-eq-received-an-invalid-combination-of-arguments