Detectron 源码解析-日志输出控制及训练状态跟踪

TrainingStats 类

training_stats = TrainingStats(model)

在训练文件 detectron/utils/train.py 中创建(create_model())并配置(setup_model_for_training)完模型以后, 调用了位于detectron/utils/training_stats.pyclass TrainingStats()类, 调用语句如下所示:

1
2
3
4
5
6
# detectron/utils/train.py

def train_model():
#...
training_stats = TrainingStats(model)
#...

下面, 来看看这个类具体的内部实现:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
# detectron/utils/training_stats.py

# 该类用于跟踪关键的训练状态统计值
class TrainingStats:

def __init__(self, model):
# TODO 这个window size是指什么? smoothing tracked values??
self.WIN_SZ = 20
# 输出logging的iterations间隔
self.LOG_PERIOD = 20
self.smoothed_losses_and_metrics = {
# from detectron.utils.logging import SmoothedValue
# 该类用于跟踪一系列值, 同时提供访问smoothed values的借口(基于WIN_SZ或者global series average).
key: SmoothedValue(self.WIN_SZ)
for key in model.losses + model.metrics
}
self.losses_and_metrics = {
key: 0
for key in model.losses + model.metrics
}
self.smoothed_total_loss = SmoothedValue(self.WIN_SZ)
self.smoothed_mb_qsize = SmoothedValue(self.WIN_SZ)
self.iter_total_loss = np.nan
self.iter_timer = Timer() # from detectron.utils.timer import Timer
self.model = model
#...

SmoothedValue 类

从上面的代码可以看到, TrainingStats 类中的成员大多为SmoothedValue类对象, 该类的定义位于detectron/utils/logging.py 中, 下面先来看看这个文件的内部实现:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
# detectron/utils/logging.py

def log_json_stats(stats, sort_keys=True):
# ...

class SmoothedValue(object):
# 该类用于跟踪一系列值, 同时提供访问滑动值smoothed values的借口(基于WIN_SZ或者global series average).
def __init__(self, window_size):
# from collections import deque
self.deque = deque(maxlen = window_size)
self.series = []
self.total = 0.0
self.count = 0
def AddValue(self, value):
# 将指定的值value加入到对象中的各个成员变量中
self.deque.append(value)
self.series.append(value)
self.total += value
self.count += 1
def GetMedianValue(self):
return np.median(self.deque) # axis为None , 则按照一维数组来计算deque中的中位数
def GetAverageValue(self):
return np.mean(self.deque) # 同理, 返回deque的平均值
def GetGlobalAverageValue(self):
return self.total / self.count # 返回所有值的平均值, 而不仅仅只是窗口内的平均值

#...

从上面的代码我们可以看出, 实际上SmoothedValue 类是用来维护滑动平均值的, 同时还会维护一个滑动中位数和总平均值.

Timer()

接着, 在初始化函数中, class TrainingStats类的成员变量 self.iter_timerclass Timer的类对象, 该类位于detectron/utils/timer.py文件中, 主要封装了python的time模块, 下面具体看一下实现细节

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
detectron/utils/timer.py

class Timer(object):
def __init__(self):
self.reset() # 调用类自身的reset()函数

def tic(self):
# 这里使用了time.time 而不是 time.clock, 原因是因为time.clock对于多线程任务来说可能存在一些问题
self.start_time = time.time() # 成员变量 start_time

def toc(self, average=True):
self.diff = time.time() - self.start_time # diff的值为当前时间与开始时间之间的间隔(单位为秒)
self.total_time += self.diff # 每调用一次toc函数, totaltime都会统计一次时间间隔
self.calls += 1 # 记录调用toc的次数
self.average_time = self.total_time / self.calls
if average:
return self.average_time
else:
return self.diff

def reset(self): # 将Timer内统计时间全部归0., 注意是浮点类型
self.total_time = 0.
self.calls = 0.
self.start_time = 0.
self.diff = 0.
self.average_time = 0.

再看 TrainingStats 类

接下来, 我们继续看 TrainingStats 类中的其他方法:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
# detectron/utils/timer.py

class TrainingStats(object):

def __init__(self, model):
#...

def IterTic(self):
self.iter_timer.tic() # 调用Timer类的tic方法, 记录当前time.time()时间
def IterTic(self):
return self.iter_timer.toc(average=False) # 返回距离上次掉要tic方法的时间间隔(单位为秒)
def ResetIterTimer(self):
self.iter_timer.reset() # 重置所有时间相关的统计数据
def UpdateIterStats(self):
# 更新跟踪的迭代统计信息
for k in self.losses_and_metrics.keys():
if k in self.model.losses:
# import detectron.utils.net as nu
self.losses_and_metrics[k] = nu.sum_multi_gpu_blob(k) # 计算多个gpu上的数据和
else:
self.losses_and_metrics[k] = nu.average_multi_gpu_blob(k)
for k, v in self.smoothed_losses_and_metrics.items():
v.AddValue(self.losses_and_metrics[k])
self.iter_total_loss = np.sum(
np.array([self.losses_and_metrics[k] for k in self.model.losses])
)
self.smoothed_total_loss.AddValue(self.iter_total_loss)
self.smoothed_mb_qsize.AddValue(
self.model.roi_data_loader._minibatch_queue.qsize()
)
def LogIterStats(self, cur_iter, lr):
# 记录跟踪的统计信息
if(cur_iter % self.LOG_PERIOD == 0 or
cur_iter == cfg.SOLVER.MAX_ITER - 1):
stats = self.GetStats(cur_iter, lr)
log_json_stats(stats)
def GetStats(self, cur_iter, lr):
eta_seconds = self.iter_timer.average_time * (
cfg.SOLVER.MAX_ITER - cur_iter
) # 剩余时间
eta = str(datetime.timedelta(seconds=int(eta_seconds)))
mem_stats = c2_py_utils.GetGPUMemoryUsageStats()
mem_usage = np.max(mem_stats['max_by_gpu'][:cfg.NUM_GPUS])
stats = dict(
iter=cur_iter,
lr=float(lr),
time=self.iter_timer.average_time,
loss=self.smoothed_total_loss.GetMedianValue(),
eta=eta,
mb_qsize=int(
np.round(self.smoothed_mb_qsize.GetMedianValue()),
),
mem=int(np.ceil(mem_usage / 1024 / 1024)) # 将字节转换成GB
)
for k, v in self.smoothed_losses_and_metrics.items():
stats[k] = v.GetMedianValue()
return stats