关于深度学习:机器学习洞察-JAX机器学习领域的新面孔

88次阅读

共计 4412 个字符,预计需要花费 12 分钟才能阅读完成。

在之前的《机器学习洞察》系列文章中,咱们别离针对于多模态机器学习和分布式训练、无服务器推理进行了解读,本文将为您重点介绍 JAX 的倒退并分析其演变和动机。上面,就让咱们来认识一下 JAX 这一新崛起的深度学习框架——

亚马逊云科技开发者社区为开发者们提供寰球的开发技术资源。这里有技术文档、开发案例、技术专栏、培训视频、流动与比赛等。帮忙中国开发者对接世界最前沿技术,观点,和我的项目,并将中国优良开发者或技术举荐给寰球云社区。如果你还没有关注 / 珍藏,看到这里请肯定不要匆匆划过,点这里让它成为你的技术宝库!

开源机器学习框架的演进

从这张 GitHub Star 趋势图能够看到,自 2019 年 JAX 呈现到现在放弃着一个向上的抛物线走势。

在考查一个开源机器学习框架时,例如开发者熟知的 PyTorch, TensorFlow, MXNet 等,往往会从反对模型的广泛性、部署的成熟性、生态系统的丰富性来对它做一个评估:包含是否反对 Hugging Face 等支流模型,以及其框架相干钻研论文的数量,还有它可提供复现代码的论文数量等等。

JAX 的源起

  • 为什么 Eager 模式是在 TensorFlow 1.4 版本之后引入的?
  • Eager 模式在 TensorFlow 2.0 之后变成了一个默认的执行模式,和原有的 Graph 模式的区别是什么?

回归并理清这些历史问题有助于开发者理解机器学习的演变逻辑,并理解 JAX 是如何汲取之前的教训,帮忙开发者更不便地实际深度学习或机器学习利用。

Eager 模式 V.S. Graph 模式

在 TF 引进了 Eager 模式之后,它会采纳更直观的界面,应用天然的 Python 代码和数据结构,而且享受更加便携的调试,在 Eager 模式中能够通过间接调用操作来检查和测试模型,而之前 Graph 这种模式有点相似于 C 和 C++,它的编程是写好程序之后要先进行编译能力运行。

Eager 模式有天然管制的流程,应用 Python 而不是图控制流,以及反对 GPU 和 TPU 的减速。做为开发者,咱们心愿能够主观地对待不同的框架,而不是比拟他们的优劣。值得思考的一个问题是:通过理解 TF 的 Eager 模式对于 Graph 模式的改良,它的改良逻辑和思路在 JAX 中都有身影。

什么是 JAX

JAX 作为当初越来越风行的库,是一种相似于 NumPy(应用 Python 开源的数值计算扩大库)的轻量级用于阵列的计算。JAX 最开始的设计不仅仅是为了深度学习而设计的,深度学习只是它的一小部分,它提供了编写 NumPy 程序的能力,这些程序能够应用 GPU/TPU 主动拆分和减速。

JAX 用于基于阵列的计算时,开发者无需批改代码就能够在 CPU/GPU/ASIC 上同时运行,并反对原生 Python 和 NumPy 函数的四种可组合函数转换:

  • 主动微分 (Autodiff)
  • 即时编译 (JIT compilation)
  • 主动向量化 (Vectorization)
  • 代码并行化 (Parallelization)

JAX 初体验

咱们能够通过上面这个简略的测试比照 JAX 和 NumPy 的计算性能。

输出一个 100 X 100 的二维数组 X,选取 ml.g4dn.12xlarge 计算实例通过 NumPy 和 JAX 别离对矩阵的前三次幂求和:

def fn(x):
  return x + x*x + x*x*x
x = np.random.randn(10000, 10000).astype(dtype='float32')
%timeit -n5 fn(x)

436 ms ± 206 µs per loop (mean ± std. dev. of 7 runs, 5 loops each)

咱们发现此计算大概须要 436 毫秒。接下来,咱们应用 JAX 实现以下计算:

jax_fn = jit(fn)
x = jnp.array(x)
%timeit jax_fn(x).block_until_ready()
3.67 ms ± 10.7 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)

JAX 仅在 3.67 毫秒内执行此计算,比 NumPy 快 118 倍以上。可见,JAX 有可能比 NumPy 快几个数量级(留神,JAX 应用 TPU 而 NumPy 正在应用 CPU)。

* 以上为集体测试后果,非官方提供的数据,仅供钻研参考

比照测试后果可得,NumPy 实现计算须要 436 毫秒,而 JAX 仅须要 3.67 毫秒,计算速度相差 100 多倍。这个测试也阐明了为什么很多开发者对它的性能拍案叫绝。

JAX 的动机分析

咱们心愿通过答复这个问题来解读 JAX 的动机:

如何应用 Python 从头开始实现高性能和可扩大的深度神经网络?

在 NumPy 中创立深度学习零碎

通常,Python 程序员会从 NumPy 之类的货色开始,因为它是一种相熟的、基于数组的数据处理语言,在 Python 社区中曾经应用了几十年。如果你想在 NumPy 中创立深度学习零碎,你能够从预测办法开始。

这里能够用一个具体的例子阐明问题,从 NumPy 上的深度学习的场景说起:

上述代码展现了订阅一个前馈的神经网络,它执行了一系列的点积和激活函数,而后将输出转化为某种能够学习的输入。一旦定义了这样的一个模型,接下来须要做就是要定义损失函数,这个函数将为你提供正在尝试优化的那些指标,来适应最佳的机器学习模型。例如以上代码的损失函数是以均方误差损失函数 MSE 为例。

当初咱们来剖析下:在深度学习场景应用 NumPy 还短少什么?

硬件加速 (GPU/TPU)

主动微分 (autodiff) 疾速优化

增加编译 (Compilation) 交融操作

向量化操作批处理 (batching)

大型数据集并行化 (Parallelization)

1)硬件加速 (GPU/TPU):首先深度学习须要大量的计算,咱们想在减速的硬件上运行它。所以咱们想在 GPU 和 TPU/ASIC 上运行这个模型,这对于经典的 NumPy 来说有点艰难;

2)主动微分 (autodiff) 疾速优化:接下来咱们想要做主动微分,这样就能够无效地拟合这个损失函数,而不用本人来实现数值微分;

3)而后咱们须要增加编译 (Compilation):这样你就能够将这些操作交融在一起,使它们更加高效;

4)向量化操作批处理 (Batching):另外,当咱们编写了某些函数后,可能心愿将其利用于多个数据片段,而不再须要重写预测和损失函数来解决这些批量数据;

5)大型数据集并行化 (Parallelization):最初,如果咱们正在解决大型数据集,会心愿可能反对跨多个 cores 或多台 machines 做并行化操作。

JAX 的动机分析:XLA 和主动定位

JAX 十分重要的一个动机就是 XLA 和主动定位。让咱们来看看 JAX 能够做些什么,来填补后面剖析的在深度学习场景应用 NumPy 还短少的性能。

首先,用 jax.numpy 替换 numpy 导入模块。在许多状况下,jax.numpy 与经典的 NumPy 具备雷同的 API,但 jax.numpy 能够实现后面剖析时发现 NumPy 短少,然而在深度学习场景却十分须要的的货色。

JAX 能够通过 XLA 后端,来主动定位 CPU、GPU 和 TPU 或者 ASIC,以便疾速计算模型和算法。

JAX 动机分析:Autograd

第二个重要动机是 Autograd。开发者能够通过上面的代码调用 Autograd 版本:

通过 from jax import grad 模块,应用 Autograd 的更新版本,JAX 能够主动微分原生 Python 和 NumPy 函数。它能够解决 Python 性能的大子集,包含循环、Ifs、递归等,甚至能够承受导数的导数。

JAX 提供了一组可组合的变换,其中之一是 grad 变换

例子中,像 mse_loss 这样的损失函数,通过 grad (mse_loss) 将其转换为计算梯度的 Python 函数。

Autograd 的次要预期利用是基于梯度的优化。

无关更多信息,请查看 JAX 教程和示例:Https://github.com/hips/autograd

JAX 动机分析:vmap

在应用梯度函数时,开发者心愿将其利用于多个数据片段,而在 JAX 中,你不再须要重写预测和损失函数来解决这些批量数据。

如图中代码最初一行 (perexample_grads …) 所诠释的那样,如果你通过 vmap transform 传递它,这会主动向量化这个代码,这样就能够在多个批次中应用雷同的代码。

JAX 动机分析:jit

JAX 还有一个重要的组合函数——jit,开发者能够应用 jit transform 实现即时编译。

jit 联合后盾能够应用 XLA 后端编译器将操作交融在一起,来主动定位 CPU、GPU 和 TPU 或者 ASIC,减速计算模型和算法。

JAX 动机分析:pmap

最初,如果想并行化你的代码,有一个和 vmap 十分类似得转换叫 pmap。

通过代码运行 pmap,开发者可能本地定位系统中的多个内核或你有权拜访的 GPU、TPU 或 ASIC 集群。

这最终成为一个十分弱小的零碎,能够在没有太多额定代码的状况下构建咱们用相似于 NumPy 的相熟 API,做深度学习的疾速计算等工作负载

JAX 的要害设计思维

通过上述比照能够看到,JAX 不仅为开发者提供了和 NumPy 类似的 API,上述的五大函数转换组合也让 JAX 能够在不须要额定代码的状况下,帮忙开发者构建深度学习利用进行疾速计算。

这里的要害思维是:

1)首先,在 JAX 中,Python 代码被追溯到两头示意,JAX 晓得如何转换这个两头示意。

2)在下篇文章中咱们也将详细分析 JAX 的工作机制:同样的两头示意,通过容许 XLA 进行特定畛域 (CPU/GPU 等) 的编译,如何来瞄准不同的后端;

3)另外,JAX 还有基于 NumPy 和 SciPy 的面向用户的 API,如果开发者始终应用 Python 的技术栈,应该会对 JAX 感觉相当相熟;

4)最初,JAX 提供了 功能强大的变换:grad, git, vmap, pmap 等,来反对深度学习等计算,因而 JAX 能够做到之前 NumPy 代码无奈做到的事件。

通过后面的介绍,咱们能够看到,开发者相熟的 API 和语法以及四种弱小的转换组合让开发者更加喜爱 JAX,并让深度学习场景或者科学计算变得十分简便。
欢送回顾对于机器学习的往期文章,以及更多面向开发者的技术分享。请继续关注 Build On Cloud 微信公众号!

往期举荐

  • 机器学习洞察 | 开掘多模态数据机器学习的价值
  • 机器学习洞察 | 分布式训练让机器学习更加疾速精确
  • 机器学习洞察 | 降本增效,无服务器推理是怎么做到的?

作者
黄浩文
亚马逊云科技资深开发者布道师,专一于 AI/ML、Data Science 等。领有 20 多年电信、挪动互联网以及云计算等行业架构设计、技术及守业治理等丰盛教训,曾就任于 Microsoft、Sun Microsystems、中国电信等企业,专一为游戏、电商、媒体和广告等企业客户提供 AI/ML、数据分析和企业数字化转型等解决方案咨询服务。

文章起源:https://dev.amazoncloud.cn/column/article/63e33239e5e05b6ff89…

正文完
 0