在遇到一些稍微复杂的搭建模型的需求的时候,使用pytorch中的 nn.ModuleList() 和nn.Sequential()可以方便很多。
使用 ModuleList 可以简化写法。
这里需要讲的是,ModuleList 可以存储多个 model,传统的方法,一个model 就要写一个 forward ,但是如果将它们存到一个 ModuleList 的话,就可以使用一个 forward。
ModuleList是Module的子类,当在Module中使用它的时候,就能自动识别为子module。
当添加 nn.ModuleList 作为 nn.Module 对象的一个成员时(即当我们添加模块到我们的网络时),所有 nn.ModuleList 内部的 nn.Module 的 parameter 也被添加作为 我们的网络的 parameter。
class model2(nn.Module):def __init__(self):super(model2, self).__init__()self.layers=nn.ModuleList([nn.Linear(1,10), nn.ReLU(),nn.Linear(10,100),nn.ReLU(),nn.Linear(100,10),nn.ReLU(),nn.Linear(10,1)])def forward(self,x):out=xfor i,layer in enumerate(self.layers):out=layer(out)return out
其它用法
ModuleList 具有和List 相似的用法,实际上可以把它视作是 Module 和 list 的结合。
除了在创建 ModuleList 的时候传入一个 module 的 列表,还可以使用extend 函数和 append 函数来添加模型。
1.extend 方法
和 list 相似,参数为一个元素为 Module的列表,该方法的效果是将列表中的所有 Module 添加到 ModuleList中:
self.linears.extend([nn.Linear(size1, size2) for i in range(1, num_layers)])
2.append 方法
和list 的append 方法一样,将 一个 Module 添加到ModuleList。
self.linears.append(nn.Linear(size1, size2)
使用 nn.Sequential()
class model3(nn.Module):def __init__(self):super(model3, self).__init__()self.network=nn.Sequential(nn.Linear(1,10),nn.ReLU(),nn.Linear(10,100),nn.ReLU(),nn.Linear(100,10),nn.ReLU(),nn.Linear(10,1))def forward(self, x):return self.network(x)