SSD 源码实现 (PyTorch)-数据处理

源码文件

  • ./data/: coco.py,
  • ./utils/: augmentations.py

augmentations.py 文件概览

众多周知, SSD 模型虽然比较简单, 但是也因此在精度上不够优秀, 故而需要借助较多的 Augmentation Trick 来提升模型的 mAP, 这部分代码位于 utils/augmentations.py 文件中, 由于这部分代码比较琐碎, 并且与 SSD 网络的关系并不大, 所以这里我们只给出一个整体概览, 不做过多注释, 有兴趣的朋友可以自己查看源码. 代码文件内容概览如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
# ./utils/augmentations.py

def intersect(box_a, box_b):
# ...
def jaccard_numpy(box_a, box_b):
# ...
class Compose(object):
# ...
class Lambda(object):
# ...
class ConvertFromInts(object):
# ...
class SubtractMeans(object):
# ...
class ToAbsoluteCoords(object):
# ...
class ToPercentCoords(object):
# ...
class Resize(object):
# ...
class RandomSaturation(object):
# ...
class RandomHue(object):
# ...
class RandomLightingNoise(object):
def __init__(self):
# ...
def __call__(self, image, boxes=None, labels=None):
# ...
class ConvertColor(object):
def __init__(self, current="BGR", transform="HSV"):
# ...
def __call__(self, image, boxes=None, labels=None):
# ...
class RandomContrast(object):
def __init__(self, lower=0.5, upper=1.5):
# ...
def __call__(self, image, boxes=None, labels=None):
# ...
class RandomBrightness(object):
def __init__(self, delta=32):
# ...
def __call__(self, image, boxes=None, labels=None):
# ...
class ToCV2Image(object):
# ...
class ToTensor(object):
# ...
class RandomSampleCrop(object):
def __init__(self):
# ...
def __call__(self, image, boxes=None, labels=None):
# ...
class Expand(object):
def __init__(self, mean):
# ...
def __call__(self, image, boxes, labels):
# ...
class RandomMirror(object):
def __call__(self, image, boxes, classes):
# ...
class SwapChannels(object):
# ...
class PhotometricDistort(object):
# ...
class SSDAugmentation(object):
# ...

class SSDAugmentation(object) 类

上面文件中的函数之间的具有明确的调用关系和顺序, 因此我们将根据这些函数之间的逻辑关系来对该文件进行解析. 由于在模型训练创建数据集时, 调用了 class SSDAugmentation(object) 类, 因此, 我们首先对该类进行解析, 该类的定义如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
# ./utils/augmentation.py

class SSDAugmentation(object):
def __init__(self, size=300, mean=(104, 117, 123)):
self.mean = mean
self.size = size
self.augment = Compose([
ConvertFromInts(),
ToAbsoluteCoords(),
PhotometricDistort(),
Expand(self.mean),
RandomSampleCrop(),
RandomMirror(),
ToPercentCoords(),
Resize(self.size),
SubtractMeans()
])

def __call__(self, img, boxes, labels):
return self.augment(img, boxes, labels)

class Compose(object) 类

可以看到, 上面的类中主要包含三个成员, 一个是图片 BGR 的平均值 mean, 一个是图片的尺寸大小, 还有一个使用了本文件中的 class Compose(object), 并想该类传递了一个列表参数, 列表内的元素类型实现了 __call__ 方法的类, 也就是可以直接当做函数进行调用, 下面我们先来看一下 class Compose(object) 类的实现:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
# ./utils/augmentation.py

class Compose(object):
# 该类用于将多个 transformations 函数组合起来

def __init__(self, transforms):
# transforms (list): 列表元素为各种 transformations
self.transforms = transforms

def __call__(self, img, boxes=None, labels=None):
# 按照列表中顺序依次执行各种 transformations, 最终返回三个对象
for t in self.transforms:
img, boxes, labels = t(img, boxes, labels)
return img, boxes, labels

可以看到, Compose 的作用实际上就是按照顺序不断调用列表内的类函数来对数据进行 transformations, 下面我们就根据默认调用的函数顺序对文件中的其他类进行解析.

class ConvertFromInts(object) 类

该类将 image 数据中的像素类型从整形变成浮点型.

1
2
3
4
5
# ./utils/augmentation.py
class ConvertFromInts(object):

def __call__(self, image, boxes=None, labels=None):
return image.astype(np.float32), boxes, labels

class ToAbsoluteCoords(object) 类

我们知道, 在进行目标检测时, 默认的生成的 boxes 的坐标值是按照图片的长宽比例来存储的, 这里为了方便我们后续的 transformations 的执行, 因此, 我们需要暂时将 boxes 的坐标切换成绝对值坐标. (事后会切换回来)

1
2
3
4
5
6
7
8
9
10
11
# ./utils/augmentation.py
class ToAbsoluteCoords(object):

def __call__(self, image, boxes, labels):
width, height, channels = image.shape
boxes[:, 0] *= width
boxes[:, 2] *= width
boxes[:, 1] *= height
boxes[:, 3] *= height

return image, boxes, labels

class PhotometricDistort(object) 类

该类会随机选择一些图片执行扭曲操作(distort), 可以看做是一种数据增广技术, 代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
# ./utils/augmentation.py
class PhotometricDistort(object):

def __init__(self):
# pd 是一个列表, 其中存放的多个 transformations 类.
# pd 将会作为 Compose 的参数使用
self.pd = [
RandomContrast(),
ConvertColor(transform='HSV'),
RandomSaturation(),
RandomHue(),
ConvertColor(current='HSV', transform='BGR'), # opencv
RandomContrast()
]
# 随机调节亮度
self.rand_brightness = RandomBrightness()
# 随机增加噪声
self.rand_light_noise = RandomLightingNoise()

def __call__(self, image, boxes, labels):
im = image.copy()
im, boxes, labels = self.rand_brightness(im, boxes, labels)
if random.randint(2): # 随机执行下面两者操作之一
distort = Compose(self.pd[:-1])
else:
distort = Compose(self.pd[1:])
im, boxes, labels = distort(im, boxes, labels)
return self.rand_light_noise(im, boxes, labels)

上面的类中包含了许多数据增广的方法, 下面我们会逐个进行介绍.

class RandomContrast(object) 类

令图片中所有像素的值都乘以一个介于 [lower, upper] 之间的随机系数. 该操作会随机执行.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
# ./utils/augmentation.py
class RandomContrast(object):

def __init__(self, lower=0.5, upper=1.5):
self.lower = lower
self.upper = upper
assert self.upper >= self.upper
assert self.lower >= 0

def __call__(self, image, boxes=None, labels=None):
# 随机执行
if random.randint(2):
alpha = random.uniform(self.lower, self.upper)
image *= alpha

return image, boxes, labels

class ConvertColor(object) 类

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# ./utils/augmentation.py
class ConvertColor(object):

def __init__(self, current='BGR', transform='HSV'):
self.current = current
self.transform = transform

def __call__(self, image, boxes=None, labels=None):
if self.current == 'BGR' and self.transforms == 'HSV':
image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
elif self.current == 'HSV' and self.transforms == 'BGR':
image = cv2.cvtColor(image, cv2.COLOR_HSV2BGR)
else:
raise NotImplementError
return image, boxes, labels

class RandomSaturation(object) 类

随机的改变图片的饱和度, 会给图片中的绿色通道的值乘以一个 [lower, upper] 之间的随机值

1
2
3
4
5
6
7
8
9
10
11
12
13
# ./utils/augmentation.py
class RandomSaturation(object):

def __init__(self, lower=0.5, upper=1.5):
self.lower = lower
self.upper = upper
assert self.upper >= self.lower
assert self.lower >=0

def __call__(self, image, boxes=None, labels=None):
if random.randint(2):
image[:, :, 1] *= random.uniform(self.lower, self.upper)
return image, boxes, labels

class RandomHue(object) 类

随机改变图片的色调, 对蓝色通道值进行修改.

1
2
3
4
5
6
7
8
9
10
11
12
13
# ./utils/augmentation.py
class RandomHue(object):

def __init__(self, delta=18.0):
assert delta >= 0.0 and delta <= 360.0
self.delta = delta

def __call__(self, image, boxes=None, labels=None):
if random.randint(2):
image[:, :, 0] += random.uniform(-self.delta, self.delta)
image[:, :, 0][image[:, :, 0] > 360.0] -= 360.0
image[:, :, 0][image[:, :, 0] < 0.0] += 360.0
return image, boxes, labels

class RandomBrightness(object) 类

随机调节图片的亮度, 给 BGR 三通道随机加上或减去一个值.

1
2
3
4
5
6
7
8
9
10
11
12
13
# ./utils/augmentation.py
class RandomBrightness(object):

def __init__(self, delta=32):
assert delta >= 0.0
assert delta <= 255.0
self.delta = delta

def __call__(self, image, boxes=None, labels=None):
if random.randint(2):
delta = random.uniform(-self.delta, self.delta)
image += delta
return image, boxes, labels

class RandomLightingNoise(object) 类

该类会随机给图片增加噪声, 随机调换 BGR 三通道的顺序. 在该类的使用了 class SwapChannels(object) 类来辅助完成目的功能, 代码会在下面一并给出:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
# ./utils/augmentation.py
class RandomLightingNoise(object):
def __init__(self):
self.perms = (
(0,1,2), (0,2,1),
(1,0,2), (1,2,0),
(2,0,1), (2,1,0)
)
def __call__(self, image, boxes=None, labels=None):
if random.randint(2):
swap = random.randint(len(self.perms))
shuffle = SwapChannels(swap)
image = shuffle(image)
return image, boxes, labels

class SwapChannels(object):
def __init__(self, swaps):
self.swaps = swaps

def __call__(self, image):
image = image[:, :, self.swaps]

return image, boxes, labels

class Expand(object)

该操作将创建一个数倍于原图的画板, 将图片随机放置在画板中的某部分, 其他部分的像素值权值填充成参数 mean 的值.

1
# ./utils/augmentation.py

class RandomSampleCrop(object)

对图片进行随机裁剪

1
# ./utils/augmentation.py

class RandomMirror(object)

将图片横向反转(就向镜子的功能一样)

1
# ./utils/augmentation.py

class ToPercentCoords(object)

将当前 numpy 数组内的图片的真实坐标转换成百分比坐标(相对于图片的尺寸而言)

1
# ./utils/augmentation.py

class Resize(object)

调用 cv2.resize() 函数将图片的大小进行缩放

1
# ./utils/augmentation.py

class SubtractMeans(object)

将图片中的像素值减去给定的平均值参数 meand

1
# ./utils/augmentation.py