当前位置: 代码迷 >> 综合 >> Pytorch中torch.nn.Linear()的分析
  详细解决方案

Pytorch中torch.nn.Linear()的分析

热度:108   发布时间:2023-11-25 14:01:02.0

torch.nn.Linear(in_featuresout_featuresbias=Truedevice=Nonedtype=None)

  • in_features – 输入样本的大小

  • out_features – 输出样本的大小

  • bias – 如果设置为False,该层将不会学习偏移量。默认值:True

nn.Linear()的作用是对传入的数据进行线性变换。

 

 传入的数据是x,经过变换后得到y,A是权重,b是偏移量。

查看官方文档,最开始生成的权重矩阵A的shape为(out_features, in_features)

m = nn.Linear(20,30)
# 查看生成的权重矩阵的shape,会发现权重的shape为30*20
print(m.weight.size())# 生成一个128*20的矩阵
input = torch.randn(128, 20)
# 查看生成矩阵的大小
print(input.size())# 传入input(shape为128*20)后,input和权重矩阵A转置后的矩阵
# 相乘(转置后A的shape变为20*30),最终得到的结果shape为128*30
print(m(input).size())

程序运行生成的结果