关于ml:动手用-Java-训练深度学习模型

5次阅读

共计 6648 个字符,预计需要花费 17 分钟才能阅读完成。

很长时间以来,Java 始终是一个很受企业欢送的编程语言。得益于丰盛的生态以及欠缺保护的包和框架,Java 领有着宏大的开发者社区。只管深度学习利用的一直演进和落地,提供给 Java 开发者的框架和库却非常短缺。现今次要风行的深度学习模型都是用 Python 编译和训练的。对于 Java 开发者而言,如果要进军深度学习界,就须要重新学习并承受一门新的编程语言同时还要学习深度学习的简单常识。这使得大部分 Java 开发者学习和转型深度学习开发变得困难重重。

为了缩小 Java 开发者学习深度学习的老本,亚马逊云科技构建了 Deep Java Library (DJL),一个为 Java 开发者定制的开源深度学习框架。它为 Java 开发者对接支流深度学习框架提供了一个桥梁。DJL 同时对 Apache MXNet,PyTorch 和 TensorFlow 最新版本的反对,使得开发者能够轻松应用 Java 构建训练和推理工作。在这个文章中,咱们会尝试用 DJL 构建一个深度学习模型并用它训练 MNIST 手写数字辨认工作。

什么是深度学习?

在咱们正式开始之前,咱们先来理解一下机器学习和深度学习的基本概念。机器学习是一个通过利用统计学常识,将数据输出到计算机中进行训练并实现特定指标工作的过程。这种演绎学习的办法能够让计算机学习一些特色并进行一系列简单的工作,比方辨认照片中的物体。因为须要写简单的逻辑以及测量规范,这些工作在传统计算迷信畛域中很难实现。

深度学习是机器学习的一个分支,次要侧重于对于人工神经网络的开发。人工神经网络是通过钻研人脑如何学习和实现目标的过程中演绎而得出一套计算逻辑。它通过模仿局部人脑神经间信息传递的过程,从而实现各类简单的工作。深度学习中的“深度”来源于咱们会在人工神经网络中编织构建出许多层 (layer) 从而进一步对数据信息进行更深层的传导。深度学习技术利用范畴非常宽泛,当初被用来做指标检测,动作辨认,机器翻译,语意剖析等各类事实利用中。

📢想要理解更多亚马逊云科技最新技术公布和实际翻新,敬请关注在上海、北京、深圳三地举办的2021 亚马逊云科技中国峰会!点击图片报名吧~

训练 MNIST 手写数字辨认

我的项目配置

你能够用如下的 gradle 配置来引入依赖项。在这个案例中,咱们用 DJL 的 api 包 (外围 DJL 组件) 和 basicdataset 包 (DJL 数据集) 来构建神经网络和数据集。这个案例中咱们应用了 MXNet 作为深度学习引擎,所以咱们会引入 mxnet-engine 和 mxnet-native-auto 两个包。这个案例也能够运行在 PyTorch 引擎下,只须要替换成对应的软件包即可。

1plugins {
2    id 'java'
3}
4repositories {5    jcenter()
6}
7dependencies {8    implementation platform("ai.djl:bom:0.8.0")
9    implementation "ai.djl:api"
10    implementation "ai.djl:basicdataset"
11    // MXNet
12    runtimeOnly "ai.djl.mxnet:mxnet-engine"
13    runtimeOnly "ai.djl.mxnet:mxnet-native-auto"
14}

NDArray 和 NDManager

NDArray 是 DJL 存储数据结构和数学运算的根本构造。一个 NDArray 表白了一个定长的多维数组。NDArray 的应用办法相似于 Python 中的 numpy.ndarray。

NDManager 是 NDArray 的老板。它负责管理 NDArray 的产生和回收过程,这样能够帮忙咱们更好的对 Java 内存进行优化。每一个 NDArray 都会是由一个 NDManager 发明进去,同时它们会在 NDManager 敞开时一起敞开。NDManager 和 NDArray 都是由 Java 的 AutoClosable 构建,这样能够确保在运行完结时及时进行回收。想理解更多对于它们的用法和实际,请参阅这篇文章。

Model

在 DJL 中,训练和推理都是从 Model class 开始构建的。咱们在这里次要讲训练过程中的构建办法。上面咱们为 Model 创立一个新的指标。因为 Model 也是继承了 AutoClosable 构造体,咱们会用一个 try block 实现:

1try (Model model = Model.newInstance()) {
2    ...
3    // 主体训练代码
4    ...
5}

筹备数据

MNIST ((Modified National Institute of Standards and Technology) 数据库蕴含大量手写数字的图,通常被用来训练图像处理零碎。DJL 曾经将 MNIST 的数据集收录到了 basicdataset 数据集里,每个 MNIST 的图的大小是 28 x 28。如果你有本人的数据集,你也能够通过 DJL 数据集导入教程来导入数据集到你的训练任务中。

1int batchSize = 32; // 批大小
2Mnist trainingDataset = Mnist.builder()
3        .optUsage(Usage.TRAIN) // 训练集
4        .setSampling(batchSize, true)
5        .build();
6Mnist validationDataset = Mnist.builder()
7        .optUsage(Usage.TEST) // 验证集
8        .setSampling(batchSize, true)
9        .build();

这段代码别离制作出了训练和验证集。同时咱们也随机排列了数据集从而更好的训练。除了这些配置以外,你也能够增加对于图片的进一步解决,比方设置图片大小,对图片进行归一化等解决。

制作 model (建设 Block)

当你的数据集准备就绪后,咱们就能够构建神经网络了。在 DJL 中,神经网络是由 Block (代码块) 形成的。一个 Block 是一个具备多种神经网络个性的构造。它们能够代表 一个操作, 神经网络的一部分, 甚至是一个残缺的神经网络。而后 Block 能够程序执行或者并行。同时 Block 自身也能够带参数和子 Block。这种嵌套构造能够帮忙咱们结构一个简单但又不失维护性的神经网络。在训练过程中,每个 Block 中附带的参数会被实时更新,同时也包含它们的各个子 Block。这种递归更新的过程能够确保整个神经网络失去充沛训练。

当咱们构建这些 Block 的过程中,最简略的形式就是将它们一个一个的嵌套起来。间接应用筹备好 DJL 的 Block 品种,咱们就能够疾速制作出各类神经网络。

依据几种根本的神经网络工作模式,咱们提供了几种 Block 的变体。SequentialBlock 是为了应答程序执行每一个子 Block 结构而成的。它会将前一个子 Block 的输入作为下一个 Block 的输出 继续执行到底。与之对应的,是 ParallelBlock。ParallelBlock 用于将一个输出并行输出到每一个子 Block 中,同时将输入后果依据特定的合并方程合并起来。最初咱们说一下 LambdaBlock,它是帮忙用户进行疾速操作的一个 Block,其中并不具备任何参数,所以也没有任何局部在训练过程中更新。

咱们来尝试创立一个根本的多层感知机 (MLP) 神经网络吧。多层感知机是一个简略的前向型神经网络,它只蕴含了几个全连贯层 (LinearBlock)。那么构建这个网络,咱们能够间接应用 SequentialBlock。

1int input = 28 * 28; // 输出层大小
2int output = 10; // 输入层大小
3int[] hidden = new int[] {128, 64}; // 暗藏层大小
4SequentialBlock sequentialBlock = new SequentialBlock();
5sequentialBlock.add(Blocks.batchFlattenBlock(input));
6for (int hiddenSize : hidden) {
7    // 全连贯层
8    sequentialBlock.add(Linear.builder().setUnits(hiddenSize).build());
9    // 激活函数
10    sequentialBlock.add(activation);
11}
12sequentialBlock.add(Linear.builder().setUnits(output).build());

当然 DJL 也提供了间接就能够拿来用的 MLP Block :

1Block block = new Mlp(
2        Mnist.IMAGE_HEIGHT * Mnist.IMAGE_WIDTH,
3        Mnist.NUM_CLASSES,
4        new int[] {128, 64});
5sequentialBlock.add(Linear.builder().setUnits(output).build());

训练

当咱们筹备好数据集和神经网络之后,就能够开始训练模型了。在深度学习中,个别会由上面几步来实现一个训练过程:

初始化: 咱们会对每一个 Block 的参数进行初始化,初始化每个参数的函数都是由 设定的 Initializer 决定的。

  • 前向流传: 这一步将输出数据在神经网络中逐层传递,而后产生输入数据。
  • 计算损失: 咱们会依据特定的损失函数 Loss 来计算输入和标记后果的偏差。
  • 反向流传: 在这一步中, 你能够利用损失反向求导算出每一个参数的梯度。
  • 更新权重: 咱们会依据抉择的优化器 (Optimizer) 更新每一个在 Block 上参数的值。

DJL 利用了 Trainer 构造体精简了整个过程。开发者只须要创立 Trainer 并指定对应的 Initializer , Loss 和 Optimizer 即可。这些参数都是由 TrainingConfig 设定的。上面咱们来看一下具体的参数设置:

  • TrainingListener: 这个是对训练过程设定的监听器。它能够实时反馈每个阶段的训练后果。这些后果能够用于记录训练过程或者帮忙 debug 神经网络训练过程中的问题。用户也能够定制本人的 TrainingListener 来对训练过程进行监听。
1DefaultTrainingConfig config = new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss())
2    .addEvaluator(new Accuracy())
3    .addTrainingListeners(TrainingListener.Defaults.logging());
4try (Trainer trainer = model.newTrainer(config)){
5    // 训练代码
6}

当训练器产生后,咱们能够定义输出的 Shape。之后就能够调用 fit 函数来进行训练。fit 函数会对输出数据,训练多个 epoch 是并最终将后果存储在本地目录下。

1/*
2 * MNIST 蕴含 28x28 灰度图片并导入成 28 * 28 NDArray。3 * 第一个维度是批大小, 在这里咱们设置批大小为 1 用于初始化。4 */
5Shape inputShape = new Shape(1, Mnist.IMAGE_HEIGHT * Mnist.IMAGE_WIDTH);
6int numEpoch = 5;
7String outputDir = "/build/model";
8
9// 用输出初始化 trainer
10trainer.initialize(inputShape);
11
12TrainingUtils.fit(trainer, numEpoch, trainingSet, validateSet, outputDir, "mlp");

这就是训练过程的全副流程了!用 DJL 训练是不是还是很轻松的?之后看一下输入每一步的训练后果。如果你用了咱们默认的监听器,那么输入是相似于下图:

1[INFO] - Downloading libmxnet.dylib ...
2[INFO] - Training on: cpu().
3[INFO] - Load MXNet Engine Version 1.7.0 in 0.131 ms.
4Training:    100% |████████████████████████████████████████| Accuracy: 0.93, SoftmaxCrossEntropyLoss: 0.24, speed: 1235.20 items/sec
5Validating:  100% |████████████████████████████████████████|
6[INFO] - Epoch 1 finished.
7[INFO] - Train: Accuracy: 0.93, SoftmaxCrossEntropyLoss: 0.24
8[INFO] - Validate: Accuracy: 0.95, SoftmaxCrossEntropyLoss: 0.14
9Training:    100% |████████████████████████████████████████| Accuracy: 0.97, SoftmaxCrossEntropyLoss: 0.10, speed: 2851.06 items/sec
10Validating:  100% |████████████████████████████████████████|
11[INFO] - Epoch 2 finished.NG [1m 41s]
12[INFO] - Train: Accuracy: 0.97, SoftmaxCrossEntropyLoss: 0.10
13[INFO] - Validate: Accuracy: 0.97, SoftmaxCrossEntropyLoss: 0.09
14[INFO] - train P50: 12.756 ms, P90: 21.044 ms
15[INFO] - forward P50: 0.375 ms, P90: 0.607 ms
16[INFO] - training-metrics P50: 0.021 ms, P90: 0.034 ms
17[INFO] - backward P50: 0.608 ms, P90: 0.973 ms
18[INFO] - step P50: 0.543 ms, P90: 0.869 ms
19[INFO] - epoch P50: 35.989 s, P90: 35.989 s

当训练后果实现后,咱们能够用方才的模型进行推理来辨认手写数字。如果方才的内容哪里有不是很分明的,能够参照上面两个链接间接尝试训练。

手写数据集训练:
http://docs.djl.ai/examples/d…

手写数据集推理:
http://docs.djl.ai/jupyter/tu…

总结

在这个文章中,咱们介绍了深度学习的基本概念,同时还有如何优雅的利用 DJL 构建深度学习模型并进行训练。DJL 也提供了更加多样的数据集和神经网络。如果有趣味学习深度学习,能够参阅咱们的 D2L Java 书(https://d2l.djl.ai/)。

参考链接

DJL 官网: 
https://djl.ai/

知乎专栏——DJL 深度学习库:
https://zhuanlan.zhihu.com/c_…

也欢送退出 DJL 的 slack 论坛:
https://app.slack.com/client/…

本篇作者


兰青
亚马逊云科技 AI 软件开发工程师
DJL 深度学习框架作者之一,Apache 软件基金会项目管理委员会成员。

正文完
 0