用PyTorch实现经典VGG网络

首先, 来看一下原文中关于 VGG 网络的结构设置, 如下图所示:

可以看到, 上图中, 不同版本的 VGG 网络的整体结构差不多, 主要的不同体现在每一个卷积段内(共5个卷积段)卷积层的个数以及卷积层的参数, 下面我们以 VGG-19 为例, 给出 VGG 网络的 PyTorch 实现, 其他版本的 VGG 网络可以用同样方式进行定义.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18

import torch
import torch.nn as nn
class VGGNet(nn.Module):
    """VGG-style CNN defined layer by layer (explicit form).

    Only the first two convolutional stages are shown; the original source
    leaves the remaining stages as a TODO.
    """

    def __init__(self, num_classes):
        """
        Args:
            num_classes: number of output classes (stored for later use by
                the classifier head, which is not defined in this snippet).
        """
        super(VGGNet, self).__init__()
        self.num_classes = num_classes
        # Stage 1: two 3x3 convs, 3 -> 64 -> 64 channels, then 2x2 max-pool.
        self.conv1_1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, padding=1)
        self.relu1_1 = nn.ReLU(inplace=True)
        # BUG FIX: in_channels must be 64 (the output of conv1_1), not 3.
        self.conv1_2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1)
        self.relu1_2 = nn.ReLU(inplace=True)
        self.max1 = nn.MaxPool2d(kernel_size=2, stride=2)
        # Stage 2: two 3x3 convs, 64 -> 128 -> 128 channels, then 2x2 max-pool.
        self.conv2_1 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1)
        self.relu2_1 = nn.ReLU(inplace=True)
        # BUG FIX: in_channels must be 128 (the output of conv2_1), not 64.
        self.conv2_2 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding=1)
        self.relu2_2 = nn.ReLU(inplace=True)
        self.max2 = nn.MaxPool2d(kernel_size=2, stride=2)
        # ...TODO: define the remaining stages (19 weighted layers in total for VGG-19).

上面的定义方式比较直观, 但是不够简洁, 由于 VGGNet 的结构设计比较有规律, 因此我们可以用下面的代码使模型定义变得更加整洁:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
# vgg16: exactly 16 of these entries carry learnable parameters (13 convs + 3 FC-style layers)
net_arch16 = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M5', "FC1", "FC2", "FC"]

# vgg19: same layout as vgg16, but each of the last 3 conv stages has one extra conv layer
net_arch19 = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M5', "FC1", "FC2", "FC"]

import torch
import torch.nn as nn
class VGGNet(nn.Module):
    """VGG network built from a compact architecture-specification list.

    Each entry of ``net_arch`` is interpreted as follows:
      * int   -> 3x3 conv (padding 1) with that many output channels, + ReLU
      * 'M'   -> 2x2 max-pool, stride 2 (halves spatial size)
      * 'M5'  -> 3x3 max-pool, stride 1, padding 1 (keeps spatial size)
      * 'FC1' -> dilated 3x3 conv 512 -> 1024 (fully-convolutional fc6), + ReLU
      * 'FC2' -> 1x1 conv 1024 -> 1024 (fully-convolutional fc7), + ReLU
      * 'FC'  -> 1x1 conv 1024 -> num_classes (classifier head, no ReLU)
    """

    def __init__(self, net_arch, num_classes):
        """
        Args:
            net_arch: architecture list, e.g. ``net_arch16`` or ``net_arch19``.
            num_classes: number of output classes for the final 1x1 conv head.
        """
        super(VGGNet, self).__init__()
        self.num_classes = num_classes
        layers = []
        in_channels = 3  # the network expects 3-channel (RGB) input
        for arch in net_arch:
            if arch == 'M':
                layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
            elif arch == 'M5':
                layers.append(nn.MaxPool2d(kernel_size=3, stride=1, padding=1))
            elif arch == "FC1":
                # Dilated conv enlarges the receptive field without extra pooling.
                layers.append(nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=3, padding=6, dilation=6))
                layers.append(nn.ReLU(inplace=True))
            elif arch == "FC2":
                layers.append(nn.Conv2d(1024, 1024, kernel_size=1))
                layers.append(nn.ReLU(inplace=True))
            elif arch == "FC":
                layers.append(nn.Conv2d(1024, self.num_classes, kernel_size=1))
            else:
                # BUG FIX: the original line was missing its closing parenthesis.
                layers.append(nn.Conv2d(in_channels=in_channels, out_channels=arch, kernel_size=3, padding=1))
                layers.append(nn.ReLU(inplace=True))
                in_channels = arch  # next conv consumes this layer's output channels
        self.vgg = nn.ModuleList(layers)

    def forward(self, input_data):
        """Apply every layer in order and return the final feature map."""
        x = input_data
        for layer in self.vgg:
            x = layer(x)
        return x

通过此方式定义后, 模型的 forward 部分非常简洁, 也很易于理解.