本篇教程主要介绍一些 Caffe2 里面基础概念, 以便帮助理解 Caffe2 模型的设计思路和基本原理.
在 Caffe2 中, 对于任意一个 operator
, 我们不仅仅需要提供 input, 同时还需要提供 weights 和 bias.(这是与 Caffe 的不同点之一)
Blobs and Workspace, Tensors
在 Caffe2 中, 数组的组织形式是 blobs
. 一个 blobs
可以看做是内存当前一块带有名字的数据. 绝大多数 blobs
都包含一个 tensor
, 在 Python 中, 会被转换成 numpy 数组来表示.
而 workspace
会存储所有的 blobs
, 下面的代码展示了如何将 blobs
添加到 workspace
中, 以及如何再次获取它们. workspace
会在你开始使用的时候对自身进行初始化.
1 | from caffe2.python import workspace, model_helper |
Nets and Operators
在 Caffe2 中, 基本模型的抽象形式为 net
. 一个 net
可以看成是一个图, 其中节点为各种 operators
, 这些 operators
会接受一系列 blobs
作为输入, 然后会输出一个或多个 blobs
.
在下面的代码块中, 我们将会创建一个简单的模型, 它包含以下三个部分:
- 一个全连接层(FC)
- 一个使用了 Softmax 的 Sigmoid 激活函数
- 交叉熵损失函数
我们利用 ModelHelper
来帮助创建模型, 它会创建以下两个互相联系的 nets
:
- 一个执行参数初始化(
ref.init_net
) - 一个执行训练逻辑(
ref.exec_net
)
1 | # Create the input data |
上面的代码首先在内存中创建了数据和标签的 blobs
. 数据的第一维代表 batch size, 即16. 许多 Caffe2 的 operators
都可以利用 ModelHelper
直接获取, 因此我们用它创建了一系列 operators
, 包括: FC
, Sigmoid
和 SoftmaxWithLoss
.ModelHelper
会创建两个 nets
: m.param_init_net
用于初始化指定的参数, 只需要执行一次即可. m.net
用于执行训练逻辑, 该过程对用户透明, 且是自动运行的.
网络的定义存储在一个 protobuf 结构当中(Google’s Protocal Buffer), 可以通过下面的方式来进行检查网络
1 | print(m.net.Proto()) |
输出如下:
1 | name: "my first net" |
通过下面的代码可以查看参数初始化网络
1 | print(m.param_init_net.Proto()) |
输出如下:
1 | name: "my first net_init" |
可以看到有两个 operators
, 它们会分别对 FC 的权重和偏置参数进行初始化.
Executing
现在, 既然我们已经定义好了用于网络训练的 operators
, 那么就可以利用下面的代码来训练我们的简单模型.
首先, 调用一次参数初始化网络:
1 | workspace.RunNetOnce(m.param_init_net) |
注意, 通常情况下, 上面的代码会将 param_init_net
的 protobuffer 结构传送给 C++ 运行时以供执行.
接下来, 创建真正用于训练的网络
1 | workspace.CreateNet(m.net) |
我们只需要创建该网络一次, 然后可以多次执行它
1 | workspace.CreateNet(m.net, overwrite=True) # 这里需要将 overwrite 设置为 True, 官方教程没有设值, 运行时会出现RuntimeError |
注意, 因为我们已经在 workspace
中定义过 net
了, 因此, 我们只需要传送 m.name
给 RunNet
函数即可, 而无需再次定义网络.
在训练迭代完成后, 可以通过 Fetch
来查看计算的结果
1 | print(workspace.Fetch("softmax")) |
Backward pass
上面的网络仅仅包含 foward pass, 因此它不会学到任何东西, 我们可以通过在每一个 operator
中添加一个梯度计算操作来实现 backward pass. 我们需要在调用 RunNetOnce()
之前在网络中插入下面的 operator
:
1 | m.AddGradientOperators([loss]) |
接着可以看看新的网络的结构
1 | print(m.net.Proto()) |
输出如下:
1 | name: "my first net" |
可以看到, 网络中增加了4个新的 operators
, 其输入为 backward 的计算结果, 输入为相应参数的梯度.