源码文件
train_net.py 文件概览
为了更好的解读 MaskrcnnBenchmark 的源代码, 我们首先来看看执行模型训练代码的脚本文件都使用了哪些类和函数, 该脚本可以训练 MaskrcnnBenchmark 中的所有模型, 因此, 我们可以从该文件出发, 顺藤摸瓜的探索, 以期最终能够对整个 MaskrcnnBenchmark 框架有一个全面系统的了解, 那么接下来就先来改一下该文件的大致结构, 如下所示:
1 |
|
train_net 导入的各种包和函数
首先来看看该文件导入了那些包和函数, 同时我们会针对性的简单介绍一下它们的主要作用和功能.1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36# ./tools/train_net.py
# 下面的这个必须在导入所有包之前导入, 不能放到其他位置. TODO 原因
from maskrcnn_benchmark.utils.env import setup_environment
# 常规包
import argparse
import os
import torch
from maskrcnn_benchmark.config import cfg # 导入默认配置信息
from maskrcnn_benchmark.data import make_data_loader # 数据集载入
from maskrcnn_benchmark.solver import make_lr_scheduler # 学习率更新策略
from maskrcnn_benchmark.solver import make_optimizer # 设置优化器, 封装了PyTorch的SGD类
from maskrcnn_benchmark.engine.inference import inference # 推演代码
from maskrcnn_benchmark.engine.trainer import do_train # 模型训练的核心逻辑. 会重点解析
# 调用了 ./maskrcnn_benchmark/modeling/detector/ 中的 build_detection_model() 函数
# 该函数和 Detectron 中的类似, 都是用来创建目标检测模型的, 这也是创建模型的入口函数, 十分重要
from maskrcnn_benchmark.modeling.detector import build_detection_model
from maskrcnn_benchmark.checkpoint import DetectronCheckpointer
# 封装了 PyTorch 的 torch.utils.collect_env.get_preety_env_info 函数, 同时附加了 PIL.__version__ 版本新
from maskrcnn_benchmark.utils.collect_env import collect_env_info
# 分布式训练相关设置, 由于我的gpu个数为1, 因此 get_rank() 会返回 0
from maskrcnn_benchmark.utils.comm import synchronize, get_rank
from maskrcnn_benchmark.utils.imports import import_file
# 封装了 logging 模块, 用于向屏幕输出一些日志信息
from maskrcnn_benchmark.utils.logger import set_up_logger
# 封装了 os.mkdirs 函数, 当文件夹已存在时会自动略过, 不会提示错误
from maskrcnn_benchmark.utils.miscellaneous import mkdir
cfg
: 详情可看MaskrcnnBenchmark 默认配置make_data_loader
: 详情可看数组载入make_lr_scheduler
: 详情可看优化器及学习率更新策略make_optimizer
: 详情可看优化器及学习率更新策略inference
: 该函数是执行模型推演逻辑的核心代码, 具体的解析请看inferencedo_train
: 该函数是执行模型训练逻辑的核心代码, 具体的解析请看do_trainbuild_detection_model
: 详情可看模型创建DetectronCheckpointer
: 详情可看DetectronCheckpointcollect_env_info
: 详情可看collect_env_infosynchronize
和get_rank
: 详情可看commimport_file
: 详情可看importsset_logger
: 详情可看loggermkdir
: 详情可看mkdir
相比于 Detectron 来说, MaskrcnnBenchmark 的默认配置文件显得相当 “清爽”, 定义的配置项也很精简, 下面就是 cfg
的部分配置清单, 输入 print(cfg)
即可看到全部配置项.1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22DATALOADER:
ASPECT_RATIO_GROUPING: True
NUM_WORKERS: 4
SIZE_DIVISIBILITY: 0
DATASETS:
TEST: ()
TRAIN: ()
INPUT:
MAX_SIZE_TEST: 1333
MAX_SIZE_TRAIN: 1333
MIN_SIZE_TEST: 800
MIN_SIZE_TRAIN: 800
PIXEL_MEAN: [102.9801, 115.9465, 122.7717]
PIXEL_STD: [1.0, 1.0, 1.0]
TO_BGR255: True
MODEL:
BACKBONE:
CONV_BODY: R-50-C4
FREEZE_CONV_BODY_AT: 2
OUT_CHANNELS: 1024
DEVICE: cuda
# ...
train_net.main() 主函数
下面我们根据脚本的执行顺序, 先来看看主函数的代码:
1 | # ./tools/train_net.py |
train_net.train() 训练脚本
1 | def train(cfg, local_rank, distributed): |
在执行模型训练逻辑时, 该函数调用了 ./maskrcnn_benchmark/engine/trainer.py
文件中的 do_train()
函数, 该函数是执行训练逻辑的核心代码, 具体的解析请看do_train
我们注意到, 在代码的第一句开头使用了 ./maskrcnn_benchmark/modeling/detector/
文件夹下面的用来创建目标检测模型的函数, 这也是创建模型的入口函数, 十分重要. 关于该函数的详细解析可以查看讲解模型创建的博文, 这里只需要知道该函数会根据我们的配置文件返回一个模型就可以了, 如果想查看该模型的具体网络结构及其参数, 可以利用 print(model)
查看, 如下所示为当前当前配置信息下的部分结构信息:1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21GeneralizedRCNN(
(backbone): Sequential(
(body): ResNet(
(stem): StemWithFixedBatchNorm(
(conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
(bn1): FrozenBatchNorm2d()
)
(layer1): Sequential(
(0): BottleneckWithFixedBatchNorm(
(downsample): Sequential(
(0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): FrozenBatchNorm2d()
)
(conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): FrozenBatchNorm2d()
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): FrozenBatchNorm2d()
(conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): FrozenBatchNorm2d()
)
#...
train_net.test() 推演脚本
1 | def test(cfg, model, distributed): |
在执行模型的推演逻辑时, 该函数调用了 ./maskrcnn_benchmark/engine/inference.py
文件中的 inference()
函数, 该函数是执行推演逻辑的核心代码, 具体的解析请看inference
test_net.py 文件概览
1 | # ./tools/test_net.py |
test_net 导入的各种包及函数
我们首先看看该文件导入的包及函数
1 | # ./tools/test_net.py |
可以看出, 在 ./tools/test_net.py
文件中导入的包和函数和 ./tools/train_net.py
差不多, 我们已经在之前简单介绍了这些包的功能和用途, 并给出了详细解析的链接, 这里我们就不再重复介绍, 有疑惑的可以翻到上面去看关于 train_net.py
导入的包及函数的解析.
test_net.main() 主函数
由于在进行模型推演时, 我们只需要准备好预训练文件, 数据集, 以及模型结构就可以完成整个推演过程, 因此在 test_net.py
脚本中只用了一个主函数来完成这些功能, 下面我们就来看看这个主函数的具体实现吧.
1 | # ./tools/test_net.py |