在介绍nn.Sequential和nn.ModuleDict之前,我们需要知道在pytorch构建的model核心是nn.Module模块,下面举个例子
class model(nn.Module):def __init__(self):super(Model, self).__init__()self.conv = nn.Conv2d(3, 20, 5)def forward(self, x):x = F.relu(conv(x))return x
在了解这个基本概念之后,我们分别介绍这两个模块
nn.Sequential
nn.Sequential继承自nn.Module模块,因此他自带forward函数,下面我们看一个例子
model = nn.Sequential(nn.Conv2d(1,20,5),nn.ReLU(),nn.Conv2d(20,64,5),nn.ReLU())
print(model)
'''
Sequential((0): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))(1): ReLU()(2): Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))(3): ReLU()
)
'''# 给每一步的模块进行命名
model = nn.Sequential(OrderedDict([('conv1', nn.Conv2d(1,20,5)),('relu1', nn.ReLU()),('conv2', nn.Conv2d(20,64,5)),('relu2', nn.ReLU())]))
print(model)
'''
Sequential((conv1): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))(relu1): ReLU()(conv2): Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))(relu2): ReLU()
)
'''input = torch.randn([1, 1, 10, 10])
output = model(input)
print(output.size()) # torch.Size([1, 64, 2, 2])
如上所示,我们可以得到一些结论
- 在nn.Sequential里面的每一个操作是逐步执行的,不可改变顺序,如果第一步的输出与第二步的输入不匹配就会报错
- 可以通过OrderedDict来改变nn.Sequential里面每一步的名字。注意,即使改变了名字,索引时也需要用0,1,2…,例如model[0]=Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1),model[‘conv1’]会报错
nn.ModuleList
nn.ModuleDict没有继承自nn.Module,所以不能像nn.Sequential那样有forward功能。可以将其看做一个列表的形式,能够将多个操作存放在一个列表里
class MyModule(nn.Module):def __init__(self):super(MyModule, self).__init__()self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(10)])def forward(self, x):# ModuleList can act as an iterable, or be indexed using intsfor i, l in enumerate(self.linears):x = self.linears[i](x)return xmodel = MyModule()
input = torch.randn([1, 10])
output = model(input)
print(output.size()) # torch.Size([1, 10])
如上所示,这里总结nn.ModelList的一些特点
- nn.ModelList是单纯的列表形式,当我们想快速构建一些操作(例如例子中的linear操作时,可以使用modellist)
- nn.ModelList不具备forward功能,所以我们调用里面的操作时,需要进行索引,然后才能运行这个操作
- nn.ModelList列表内的操作可以是乱序的,比如我先用list[3],再用list[0],而nn.Sequential的执行顺序不能打乱
为什么不能用python中的list来代替nn.ModelList呢?
因为nn.ModelList可以将里面的列表操作自动注册到整个网络中,但是如果是python的list,则会出问题,如下
class net_modlist(nn.Module):def __init__(self):super(net_modlist, self).__init__()self.modlist = nn.ModuleList([nn.Conv2d(1, 20, 5),nn.Conv2d(20, 64, 5),])def forward(self, x):for m in self.modlist:x = m(x)return xmodel = net_modlist()
for param in model.parameters():print(type(param.data), param.size())'''
nn.ModuleList
<class 'torch.Tensor'> torch.Size([20, 1, 5, 5])
<class 'torch.Tensor'> torch.Size([20])
<class 'torch.Tensor'> torch.Size([64, 20, 5, 5])
<class 'torch.Tensor'> torch.Size([64])将nn.ModuleList换为单纯的list
None # 输出为None,表示conv操作并没有加入到模型参数中
'''