torch.nn.
Linear
(in_features, out_features, bias=True, device=None, dtype=None)
-
in_features – 输入样本的大小
-
out_features – 输出样本的大小
-
bias – 如果设置为False,该层将不会学习偏移量。默认值:True
nn.Linear()的作用是对传入的数据进行线性变换。
传入的数据是x,经过变换后得到y,A是权重,b是偏移量。
查看官方文档,最开始生成的权重矩阵A的shape为(out_features, in_features)
m = nn.Linear(20,30)
# 查看生成的权重矩阵的shape,会发现权重的shape为30*20
print(m.weight.size())# 生成一个128*20的矩阵
input = torch.randn(128, 20)
# 查看生成矩阵的大小
print(input.size())# 传入input(shape为128*20)后,input和权重矩阵A转置后的矩阵
# 相乘(转置后A的shape变为20*30),最终得到的结果shape为128*30
print(m(input).size())
程序运行生成的结果