pytorch 使用 @ 做矩阵点乘时发生错误:
TypeError: unsupported operand type(s) for @: ‘numpy.ndarray’ and ‘Tensor’
代码
import torch
weights = torch.randn(784, 10) / math.sqrt(784)
weights.requires_grad_()x @ weights
原因: 你需要先把左边的x转换成tensor,才能和右边的权重相乘
解决:
x, = map (torch.tensor, (x, )
)
很简单
https://pytorch.org/tutorials/beginner/nn_tutorial.html