学习一个算法最好的方式就是自己尝试着去实现它! 因此, 在这片博文里面, 我会为大家讲解如何用PyTorch从零开始实现一个YOLOv3目标检测模型, 参考源码请在这里下载.
在正式介绍 YOLOv3 之前, 我们先将其和 YOLO 的其他版本做一个简单的比较, 它们的网络结构对比如下所示:
这里我们假设大家对YOLOv3的各个细节都比较熟悉, 因此就不对YOLOv3做过多介绍, 如果对YOLOv3不太懂的话, 可以再看看原文, 或者看看我写的YOLOv3解析.
模型实现总共会分为以下六部分:
- (一) 配置文件以及解析
- (二) 搭建YOLO模型框架
- (三) 实现自定义网络层的前向和反向传播过程
- (四) 数据类的设计与实现
- (五) 训练/测试/检测脚本的实现
- (六) 辅助函数及算法实现(目标函数, NMS算法等)
(一) 配置文件以及解析
配置文件
官方代码使用了配置文件来创建网络, cfg
文件中描述了网络的整体结构, 它相当于 caffe 中的 .protxt
文件一样. 我们也将使用官方的 cfg
文件来创建我们的网络, 点击这里下载并它放在 config/
文件夹中, 即 config/yolov3.cfg
.
1 | mkdir config |
打开该文件, 将会看到类似于下面的信息:1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28[convolutional]
batch_normalize=1
filters=64
size=3
stride=2
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=32
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=64
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from=-3
activation=linear
...
convolutional 和 shortcut
上面的信息中显示了4个 block, 其中 3 个是卷积网络层, 最后一个是 shortcut 网络层, shortcut 网络层是一种 skip connection, 就像 ResNet 中的一样, 其中的 from
参数为 -3
表示该层的输出是从往前数倒数第三层的图谱 直接相加 得到的.
upsample
cfg
文件中的 upsample
参数代表了双线性插值时使用的 stride
参数1
2[upsample]
stride=2
route
route
参数拥有 layers
属性, 它的值可以是一个, 也可以是两个, 如下所示. 当 layers
属性只含有一个值时, 它会输出指定的网络层的特征图谱, 在下面的例子中, layers=-4
, 因此, 当前的 route
网络层会输出前面的倒数第 4 个网络层的特征图谱. 当 layers
属性含有两个值时, 它会输出两个网络层的特征图谱连接(concatenated)后的特征图谱, 在下面的例子中, 当前的 route
网络层会将前一层(-1)和第 61 层的特征图片沿着深度维度(depth dimension)进行连接(concatenated), 然后输出连接后的特征图谱.1
2
3
4
5[route]
layers = -4
[route]
layers = -1, 61
net
cfg
文件中的另一种 block 类型是 net
, 它提供了网络的训练信息, 如下所示:1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16[net]
# Testing
batch=1
subdivisions=1
# Training
# batch=64
# subdivisions=16
width= 320
height = 320
channels=3
momentum=0.9
decay=0.0005
angle=0
saturation = 1.5
exposure = 1.5
hue=.1
解析配置文件
我们定义了一个名为 parse_config.py
的文件, 其内部的 parse_model_config()
函数的参数是指定的 cfg
的文件路径, 它的功能是将 cfg
文件中的信息加载到模型中, 并且用 元素为字典的列表 的形式进行存储, 如下所示:1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18# ./utils/parse_config.py
def parse_model_config(path):
f = open(path, 'r') #读取文件
module_defs = [] # 创建列表, 列表中的元素为字典
for line in f.readlines(): # 逐行读取
line = line.strip() # 消除行头尾的空白符(空格, 回车等)
if not line or line.startswith('#'): # 如果遇到空行或者注释行, 则跳过
continue
if line.startswith('['):# 遇到模块的起始, 在列表后添加新的字典
module_defs.append({})
module_defs[-1]['type'] = line[1:-1].strip() # 根据参数值为字典赋值
if(module_defs[-1]['type']=="convolutional"):
module_defs[-1]["batch_normalize"] = 0
else:
key, value = line.split('=')# 根据参数值为字典赋值, 注意要去除空白符
module_defs[-1][key.strip()] = value.strip()
return module_defs
调用该函数后, 会返回一个列表, 列表中的每个元素都是一个字典, 代表了配置文件中的以 [...]
开头的一个 block, 下面是列表中的部分元素示例:1
2
3model_config = parse_model_config("../config/yolov3-tiny.cfg")
print(model_config[0])
print(model_config[1])
输出如下:1
2
3{'channels': '3', 'hue': '.1', 'batch': '1', 'steps': '400000,450000', 'burn_in': '1000', 'max_batches': '500200', 'learning_rate': '0.001', 'exposure': '1.5', 'policy': 'steps', 'height': '416', 'width': '416', 'subdivisions': '1', 'angle': '0', 'type': 'net', 'scales': '.1,.1', 'momentum': '0.9', 'decay': '0.0005', 'saturation': '1.5'}
{'stride': '1', 'activation': 'leaky', 'type': 'convolutional', 'filters': '16', 'pad': '1', 'size': '3', 'batch_normalize': '1'}
(二) 数据类的设计与实现
在搭建 YOLO 模型之前, 我们需要先创建处理数据输入的类, 在 PyTorch 中, 通常是通过集成 torch.utils.data.Dataset
类来实现的, 我们需要实现该类的 __getitem__()
和 __len__()
方法, 实现后, 会将子类的实例作为 DataLoader
的参数, 来构建生成 batch 的实例对象. 下面, 先只给出有关数据集类的实现, 具体的加载过程在后续的脚本解析中给出.
class ImageFolder(Dataset) 类
这里我们起名为 ImageFolder
, 主要是因为原作者使用了这个名字, 实际上我不太建议使用这个名字, 因为会与 PyTorch 中 ImageFolder
类的名字冲突, 容易引起误会, 这里注意一下, 我们这里实现的 ImageFolder
类与 PyTorch 中的同名类并没有任何联系. 代码解析如下:1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32# ./utils/datasets.py
class ImageFolder(Dataset):
def __init__(self, folder_path, img_size=416):
# 获取文件夹下的所有图片路径, glob是一个用于获取路径的通配符模块
self.files = sorted(glob.glob('%s/*.*' % folder_path))
# 设置数据集的图片大小属性, 所有的图片都会被放缩到该尺寸
self.img_shape = (img_size, img_size)
def __getitem__(self, index):
img_path = self.files[index % len(self.files)] # 根据index获取图片路径
# Extract image
img = np.array(Image.open(img_path)) # 利用PIL Image读取图片, 然后转换成numpy数组
h, w, _ = img.shape # 获取图片的高和宽
dim_diff = np.abs(h - w) # 计算高宽差的绝对值
# 根据高宽差计算应该填补的像素数量(填补至高和宽相等)
pad1, pad2 = dim_diff // 2, dim_diff - dim_diff // 2
# 确定填补位置(填补到边长较短的一边)
pad = ((pad1, pad2), (0, 0), (0, 0)) if h <= w else ((0, 0), (pad1, pad2), (0, 0))
# 调用 np.pad 函数进行填补
input_img = np.pad(img, pad, 'constant', constant_values=127.5) / 255.
# 将图片放缩至数据集规定的尺寸, 同时进行归一化操作
input_img = resize(input_img, (*self.img_shape, 3), mode='reflect')
# 将通道维度放置在首位(C,H,W)
input_img = np.transpose(input_img, (2, 0, 1))
# 将numpy数组转换成tenosr, 数据类型为 float32
input_img = torch.from_numpy(input_img).float()
# 返回图片路径和图片 tensor
return img_path, input_img
def __len__(self):
return len(self.files)
class ListDataset(Dataset) 类
ListDataset
类定义了训练时所需的数据集和标签, 该类的 __getitem__()
方法会返回三个变量, 分别是: 图片路径, 经过放缩处理后的图片(尺寸大小为指定尺寸), 以及经过处理后的 box 坐标信息. 其中, 图片的存储形式为: $(C\times H\times W)$, 标签的存储形式为: $(50 \times 5)$, 这 50 条数据不一定每一条都具有意义, 对于无意义的数据, 其值为 0, 训练时直接跳过即可, 对于有意义的数据, 每一条数据的形式为: $(class_id, x, y, w, h)$, 其中, $class_id$ 是每个 box 对应的目标类别编号, $x, y, w, h$ 是每个 box 的中心点坐标和宽高, 它们都是以小数形式表示的, 也就是相对于图片宽高的比例.1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77# ./utils/datasets.py
class ListDataset(Dataset):
def __init__(self, list_path, img_size=416):
# list_path: data_config 文件中的 trian 或 val 指定的文件: trainvalno5k.txt 或者 5k.txt
# 该文件中存放了用于训练或者测试的.jpg图片的路径, 同时根据此路径可以得到对应的 labels 文件
with open(list_path, 'r') as file:
self.img_files = file.readlines()
# 根据图片的路径得到 label 的路径, label 的存储格式为一个图片对应一个.txt文件
# 文件的每一行代表了该图片的 box 信息, 其内容为: class_id, x, y, w, h (xywh都是用小数形式存储的)
self.label_files = [path.replace('images', 'labels').replace('.png', '.txt').replace('.jpg', '.txt') for path in self.img_files]
self.img_shape = (img_size, img_size) # 获取图片目标大小, 之后会将图片放缩到此大小, 并相应调整box的数据
self.max_objects = 50 # 定义每一张图片最多含有的 box 数量
def __getitem__(self, index):
# 根据index获取对应的图片路径
img_path = self.img_files[index % len(self.img_files)].rstrip()
img = np.array(Image.open(img_path))
# 如果当前获取到的图片的通道数不为3, 则跳过当前图片, 直到获取到通道数为3的图片
while len(img.shape)!=3:
index += 1
img_path = self.img_files[(index) % len(self.img_files)].rstrip()
img = np.array(Image.open(img_path))
# 获取图片的高和宽, 并根据它们的差异对图片执行 padding 操作, 使图片宽高比为1
h, w, _ = img.shape
dim_diff = np.abs(h - w)
pad1, pad2 = dim_diff//2, dim_diff - dim_diff//2
pad = ((pad1, pad2), (0,0), (0,0)) if h<=w else ((0,0), (pad1, pad2), (0,0))
input_img = np.pad(img, pad, 'constant', constant_values=128) / 255.
# 暂存padding后的图片的宽和高
padded_h, padded_w, _ = input_img.shape
# 将图片大小放缩到指定的存储, 并将通道数放置到高和宽之前
input_img = resize(input_img, (*self.img_shape, 3), mode='reflect')
input_img = np.transpose(input_img, (2,0,1))
# 将图片转化成 tensor
input_img = torch.from_numpy(input_img).float()
# 获取图片对应的 label 文件的路径
label_path = self.label_files[index % len(self.img_files)].rstrip()
labels = None
# 根据图片 padding 之后的存储, 对 label 文件中的 box 坐标按比例进行缩放
if os.path.exists(label_path):
labels = np.loadtxt(label_path).reshape(-1, 5)
x1 = w * (labels[:, 1] - labels[:, 3] / 2) # 先获取box左上角和右下角的像素坐标
y1 = h * (labels[:, 2] - labels[:, 4] / 2)
x2 = w * (labels[:, 1] + labels[:, 3] / 2)
y2 = h * (labels[:, 2] + labels[:, 4] / 2)
# 根据 padding 的大小, 更新这些坐标的值
x1 += pad[1][0]
y1 += pad[0][0]
x2 += pad[1][0]
y2 += pad[0][0]
# 重新将坐标转化成小数模式(相对应padding后的宽高的比例)
labels[:, 1] = ((x1+x2)/2) / padded_w
labels[:, 2] = ((y1+y2)/2) / padded_h
labels[:, 3] *= w / padded_w
labels[:, 4] *= h / padded_h
filled_labels = np.zeros((self.max_objects, 5)) # 创建50×5的占位空间
if labels is not None: # 将更新后的box坐标填充到刚刚申请的占位空间中
filled_labels[range(len(labels))[:self.max_objects]] = labels[:self.max_objects]
# 将 label 转化成 tensor
filled_labels =torch.from_numpy(filled_labels)
# 返回图片路径, 图片tensor, label tensor
return img_path, input_img, filled_labels
def __len__(self):
return len(self.img_files)
(三) 搭建YOLO模型框架
在 models.py
文件中, 定义了 YOLO 的模型框架, 文件概览及类之间的调用关系如下:
1 | # ./models.py |
create_modules() 函数
下面我们先来看看模型创建函数 create_modules
的代码解析:
1 | # ./models.py |
class EmptyLayer(nn.Module)
在上面的代码中, 对于 route
和 shortcut
使用了自定义的 class EmptyLayer(nn.Module)
, 该类主要起到一个占位符(placeholder)的作用, 其内部实现会根据模块的类型不同而有所区别, 下面是该类的定义:
1 | # ./models.py |
class YOLOLayer(nn.Module)
接着, 对于 yolo
模块, 使用了 class YOLOLayer(nn.Module)
, 该类的定义如下:
1 | # ./models.py |
上面 YOLOLayer
类的 forward()
函数使用了 build_targets()
函数来将真实的标签数据转化成训练用的格式, 关于该函数的解析可以看 utils.py 文件解析中的 build_target()函数
class Darknet(nn.Module)
1 | # ./models.py |
(四) 实现自定义网络层的前向和反向传播过程
(五) 训练/测试/检测脚本的实现
detect.py
该函数定义了模型的检测逻辑, 调用该函数, 会将图片送入模型中去运算, 并且会返回相应的预测结果, 然后, 需要对预测结果执行 NMS 算法, 消除重叠的框, 最后, 将预测结果以.png
的格式进行可视化存储.
1 | # ./detect.py |
train.py 训练脚本
1 | # ./train.py |
test.py 测试脚本
1 | # ./test.py |
(六) 辅助函数及算法实现(目标函数, NMS算法等)
utils.py
load_classes()
weights_init_normal()
compute_ap()
bbox_iou()
在 build_targets
函数中, 使用了 bbox_iou()
函数来计算两组 box 之间的 iou 大小, 代码实现逻辑如下所示:
1 | #./utils/utils.py |
bbox_iou_numpy()
non_max_suppression()
对预测的结果执行 NMS 算法, 传入的预测结果shape为: [1,10647,85], 最终会返回一个列表, 列表中的每个元素是每张图片的box组成的tensor, box的shape为: (x1, y1, x2, y2, object_conf, class_score, class_pred).
在 YOLO 中, 是对每一个类别(如80类)执行 NMS 算法. 而在 Faster R-CNN 中, 是对两个类进行 NMS 算法, 因此, 在 Faster R-CNN 中, 对于不同的类的 box, 如果它们的重叠度较高, 那么就会删除其中的一个.
1 | # ./utils/utils.py |
build_targets() 函数
该函数会根据 targets, anchors 以及预测的 box 来创建训练模型时使用的数据形式, 在 YOLO 中, 我们的训练目标不是直接的 box 坐标, 而是对其进行相应的编码, 然后在进行训练, 编码的方式如下所示, 数据的标注信息为 $(b_x, b_y, b_w, b_h)$, 而我们的训练目标是 $(t_x, t_y, t_w, t_h)$, 这两组数据可以互相转换.
1 | # ./utils/utils.py |