关于深度学习:MegEngine-端上训练让-AI-懂你更能保护你

64次阅读

共计 5562 个字符,预计需要花费 14 分钟才能阅读完成。

作者:Lenny | 旷视科技 MegEngine intern

刷购物 App 频频被“种草”、指纹识别一次比一次稳准快、美颜相机 get 你的爱好一键 P 图……

在智能手机上,利用 AI 算法进行个性化举荐能大幅度晋升用户的体验。然而,想让 AI 更懂你,很多利用都须要将用户数据进行模型训练,饭馆举荐背地是举荐零碎、指纹识别是利用过往数据主动优化模型、聪慧的美颜相机背地是对用户行为的剖析。

在这种状况下,如何让 AI 算法更精准地了解用户爱好又能保障用户数据安全呢?一个直观的想法就是间接在手机上进行模型训练,这样既防止了数据传输可能带来的泄露危险,又能一直晋升模型性能。MegEngine 既能够在 GPU 上进行训练,又能够在挪动设施上进行推理,那两者联合一下,是不是能够在挪动设施上进行训练呢?答案是必定的。

那么接下来,就来看一下如何在 MegEngine 外面进行端上训练吧~

依然是老规矩,拿 Mnist 数据集来进行试手,模型选用 LeNet。在咱们的内部测试中,调用端上训练接口的代码能够间接在手机上运行,并且成果和通用的 Python 训练接口齐全对齐。

回顾在 Pytorch、Tensorflow 等框架建设训练流程时候做的事件,咱们能够发现次要包含:

  1. 搭建模型;
  2. 增加 Loss 与 Optimizer;
  3. 导入数据集;
  4. 设置学习率、训练轮数等超参数并训练。

搭建模型

模型的搭建其实是结构前向计算图的一个过程,通过调用算子,获取与输出绝对应的输入。

从 LeNet 的模型构造容易得悉,咱们须要调用 2 次卷积算子,2 次池化算子,1 次 Flatten 算子,2 次矩阵乘算子,以及若干次四则运算的算子。

在 MegEngine 中,算子只是负责执行运算的一个“黑盒子”,咱们须要提前设置好参数,而后将参数与数据一起“喂”给算子。如下图所示,数据永远是逐层进行传递的,且其 Layout 会被主动计算,而参数则须要咱们手动进行设置。

对于 LeNet 这种前馈神经网络,咱们只须要将后面算子的输入与下一组参数链接到下一个算子,就能够将计算过程连接起来。

因为此处代码比拟简短,这里给出一个简化版的代码示例。能够看出,其实和调用通用的 Python 接口写法差异不大,甚至是一一对应的,比方 opr::Convolution 对应 nn.Conv2dopr::MatrixMul 对应nn.Linear,只是因为 C++ 语言个性和 Python 不同,所以写起来会有一些差别。

SymbolVar symbol_input =
           opr::Host2DeviceCopy::make(*graph, m_input); // 初始化输出数据

SymbolVar symbol_conv =
        opr::Convolution::make(symbol_input, symbol_conv_weight, conv_param); // symbol_weighs[0] 即咱们提前设置好的卷积 filter 权重
symbol_conv = opr::relu(symbol_conv + symbol_conv_bias); // 加偏置之后激活
SymbolVar symbol_maxpool =
        opr::Pooling::make(symbol_conv, pooling_param)
                .reshape({batchsize, fc_shape[0]}); // 池化之后进行展平

SymbolVar symbol_fc =
        opr::MatrixMul::make(symbol_maxpool, symbol_fc_weight) +
        symbol_fc_bias;
symbol_fc1= opr::relu(symbol_fc); // 通过矩阵乘运算结构全连贯层

通过这种形式,咱们即能够将算子、数据与参数进行组合,构建出咱们须要的前向计算图。

调用 Loss 与 Optimizer

当初 MegEngine 中曾经在 C++ 层面对 Loss 和 Optimizer 进行了封装,上面咱们以 Mnist 数据集训练中的穿插熵损失以及 SGD 优化器为例解说。

在 MegEngine 中,所有推理与训练实际上都是在一张计算图上进行,而 Loss 与 Optimizer 实质上不过是将结构计算图的一部分工作封装了起来以供用户间接调用,而无需反复“造轮子”。例如,咱们最相熟的均方误差中,实际上是调用一次减法算子之后再调用一次乘方算子。

$$ MSE\,\,=\,\,\left(y-y’’ \right) ^2 $$

明确了这一点之后,咱们只须要持续上一步,在咱们的模型输入前面调用 Loss 的 API 并进行拼接就能够,代码非常简单,和 Pytorch 中训练十分相似。

CrossEntopyLoss loss_func; // 先定义一个损失函数的实例,这里选取穿插熵损失
SymbolVar symbol_loss = loss_func(symbol_fc, symbol_label); // 将模型输入与标签作为输出,调用损失函数

这时,咱们失去的 symbol_loss 就是咱们训练过程中的损失。

与调用 Loss API 相似,咱们也能够很轻松地调用优化器插入到已有计算图中。

SGD optimizer = SGD(0.01f, 5e-4f, .9f); // 实例化 SGD 优化器并设置参数
 
SymbolVarArray symbol_updates =
        optimizer.make_multiple(symbol_weights, symbol_grads, graph); // 将 Optimizer 插入到计算图中

这样一来,在反向流传之后,梯度就会被 Optimizer 进行解决并更新模型参数。

导入数据集

既然模型参数是咱们手动定义,那必定会留神到一个问题就是咱们的数据集怎么转化成参加计算图计算的数据呢?

这个当然 MegEngine 曾经筹备好了方法,能够通过继承一个接口并实现其中的 get_itemsize办法,并将这个类的实例输出到 DataLoader 中,那么就能够实现数据集的转换啦~

咱们要继承的接口定义如下。咦,这里平时用 Pytorch 的小伙伴必定曾经闻到了相熟的滋味。

class IDataView {
public:
    virtual DataPair get_item(int idx) = 0;
    virtual size_t size() = 0;
    virtual ~IDataView() = default;};

话不多说间接上一个示例,这里只示意如何继承接口并失去 DataLoader,如果有趣味看具体实现的小伙伴能够去关注 MegEngine~

class MnistDataset : public IDataView {
public:
    MnistDataset(std::string dir_name); // 初始化数据集,指定数据集寄存门路
    void load_data(Mode mode, std::string dir_name); // 读取 Mnist 数据集,存到 dataset 列表中。DataPair get_item(int idx); // 实现接口
    size_t size(); // 实现接口
 
protected:
    std::vector<DataPair> dataset;
};

// 实例化下面定义的数据集类
auto train_dataset = std::make_shared<MnistDataset>(dataset_dir);
// 用这个实例来获取对应的 DataLoader
auto train_dataloader =
        DataLoader(train_dataset, batchsize);

训练

既然实现了各个步骤,那么接下来的事件就是让训练跑起来~ 这里也是给出简略的伪代码示例。唔……这里应用 Pytorch 的小伙伴看了也会感到十分相熟,也就是循环每个 epoch,每个 epoch 中又循环每组数据与标签,不同的是在这里咱们不须要在循环中调用 Loss 与 Optimizer,因为后面曾经结构好了残缺的计算图,这里只须要执行咱们编译后的计算图即可。

func = graph->compile(); // 编译计算图
 
for (int epoch = 0; epoch < epochs; epoch++) {for (size_t i = 0; i < train_dataloader.size(); i++) {data = train_dataloader.next(); // 从 DataLoader 中获取数据
 
        func->execute(); // 执行计算图}
}

通过我的以身试法 (x),发现在端上训练能够达到用 Pytorch 以及 MegEngine 的 Python 训练接口训练的雷同准确率~ 到这里咱们的验证即获胜利!

看到这里,置信你曾经理解了如何在 MegEngine 中进行端上训练了,那么 Loss 和 Optimizer 又到底是什么样的接口呢?

Loss 与 Optimizer 的封装

有的时候,咱们会遇到须要封装本人须要的 Loss 和 Optimizer 的状况,这时候理解 Loss 和 Optimizer 的 API 就显得比拟重要。

Loss 的接口非常简略,能够归结为如下所示:

class ILoss {
public:
    virtual mgb::SymbolVar operator()(mgb::SymbolVar symbol_pred,
                                      mgb::SymbolVar symol_label) = 0;
    virtual ~ILoss() = default;};

只有输出预测值和标签值两个计算节点,能对应输入一个计算节点即可,这里仔细的小伙伴可能曾经留神到 SymbolVar 就是后面构建前向计算图的时候用到的类,这也是为什么说 Loss 的实质就是帮忙你在计算图中插入一段计算过程。

Optimizer 的接口也很扼要,能够归结为上面的代码:

class IOptimizer {
public:
    virtual mgb::SymbolVarArray make_multiple(
            mgb::SymbolVarArray symbol_weights,
            mgb::SymbolVarArray symbol_grads,
            std::shared_ptr<mgb::cg::ComputingGraph> graph) = 0;
    virtual mgb::SymbolVar make(
            mgb::SymbolVar symbol_weight, mgb::SymbolVar symbol_grad,
            std::shared_ptr<mgb::cg::ComputingGraph> graph) = 0;
    virtual ~IOptimizer() = default;};
 
class Optimizer : public IOptimizer {
public:
    mgb::SymbolVarArray make_multiple(
            mgb::SymbolVarArray symbol_weights,
            mgb::SymbolVarArray symbol_grads,
            std::shared_ptr<mgb::cg::ComputingGraph> graph); // 留神这里并不是纯虚函数
    virtual mgb::SymbolVar make(
            mgb::SymbolVar symbol_weight, mgb::SymbolVar symbol_grad,
            std::shared_ptr<mgb::cg::ComputingGraph> graph) = 0;
    virtual ~Optimizer() = default;};

与 Loss 相似,这里咱们也是输出计算节点,而后对应输入一个计算节点。值得注意的是 Optimizer 分为了两局部,一部分是纯正的接口 IOptimizer,另一部分是继承了这个接口的抽象类Optimizer。事实上,因为很多状况下,咱们习惯于用一个数组或列表来寄存咱们的参数与失去的梯度,这时候因为动态语言的限度,不能间接将这种状况归并到繁多输出的状况中,然而实际上只有咱们实现了Make 接口,输出是数组的状况也天然会失去解决。然而思考到接口与类该当进行拆散的理念,这里进行了抽离,变成了一个接口、一个抽象类,且抽象类中蕴含了对数组输出的状况 (make_multiple接口)的默认实现。

假使须要增加一个自定义的 Loss 或 Optimizer,只须要继承相应的接口或抽象类并实现即可。

例如对均方误差 MSE 的实现:

mgb::SymbolVar MSELoss::operator()(mgb::SymbolVar symbol_pred, mgb::SymbolVar symol_label) {return opr::pow(symbol_pred - symol_label, symbol_pred.make_scalar(2));
}

总结与瞻望

看到这里,兴许你会充斥好奇,兴许你会一脸厌弃……

端上训练作为一个尚在摸索中的方向,当初确实和已有的训练、推理框架没法比拟,但 MegEngine 提供端上训练的性能会在你须要的时候为你提供一种抉择。在这样一个手机越来越占据人们生存的时代,以及人们对服务质量的需要一直进步的时代,想必端上训练会有用武之地。

以后 MegEngine 端上训练的次要问题与下一步可能的改良点有:

  • 模型的构建过程以后比拟原始,能够进一步的封装出相似 nn.module 的模块。
  • 有时候手里曾经有了带有计算图信息的某个权重文件,不心愿再次搭建计算图,而是间接读取现有的计算图并插入训练过程,能够提供相似的 API
  • 在 C++ 侧进行数据的读取会比拟麻烦

欢送大家来尝试应用 MegEngine 搭建端上训练利用,也欢送大家能指出以后 MegEngine 中端上训练存在的不足以便咱们改良,也能够来提 PR 一起解决问题~

MegEngine cpp Training Example

GitHub:旷视天元 MegEngine

欢送退出 MegEngine 技术交换 QQ 群:1029741705

正文完
 0