JAX 是机器学习 (ML) 畛域的新生力量,它无望使 ML 编程更加直观、结构化和简洁。
在机器学习畛域,大家可能对 TensorFlow 和 PyTorch 曾经耳熟能详,但除了这两个框架,一些新生力量也不容小觑,它就是谷歌推出的 JAX。很对研究者对其寄予厚望,心愿它能够取代 TensorFlow 等泛滥机器学习框架。
JAX 最后由谷歌大脑团队的 Matt Johnson、Roy Frostig、Dougal Maclaurin 和 Chris Leary 等人发动。
目前,JAX 在 GitHub 上已累积 13.7K 星。
我的项目地址:https://github.com/google/jax
迅速倒退的 JAX
JAX 的前身是 Autograd,其借助 Autograd 的更新版本,并且联合了 XLA,可对 Python 程序与 NumPy 运算执行主动微分,反对循环、分支、递归、闭包函数求导,也能够求三阶导数;依赖于 XLA,JAX 能够在 GPU 和 TPU 上编译和运行 NumPy 程序;通过 grad,能够反对主动模式反向流传和正向流传,且二者能够任意组合成任何程序。
开发 JAX 的出发点是什么?说到这,就不得不提 NumPy。NumPy 是 Python 中的一个根底数值运算库,被宽泛应用。然而 numpy 不反对 GPU 或其余硬件加速器,也没有对反向流传的内置反对,此外,Python 自身的速度限制妨碍了 NumPy 应用,所以少有研究者在生产环境下间接用 numpy 训练或部署深度学习模型。
在此状况下,呈现了泛滥的深度学习框架,如 PyTorch、TensorFlow 等。然而 numpy 具备灵便、调试不便、API 稳固等独特的劣势。而 JAX 的次要出发点就是将 numpy 的以上劣势与硬件加速联合。
目前,基于 JAX 已有很多优良的开源我的项目,如谷歌的神经网络库团队开发了 Haiku,这是一个面向 Jax 的深度学习代码库,通过 Haiku,用户能够在 Jax 上进行面向对象开发;又比方 RLax,这是一个基于 Jax 的强化学习库,用户应用 RLax 就能进行 Q-learning 模型的搭建和训练;此外还包含基于 JAX 的深度学习库 JAXnet,该库一行代码就能定义计算图、可进行 GPU 减速。能够说,在过来几年中,JAX 掀起了深度学习钻研的风暴,推动了科学研究迅速倒退。
JAX 的装置
如何应用 JAX 呢?首先你须要在 Python 环境或 Google colab 中装置 JAX,应用 pip 进行装置:
$ pip install --upgrade jax jaxlib
留神,上述装置形式只是反对在 CPU 上运行,如果你想在 GPU 执行程序,首先你须要有 CUDA、cuDNN,而后运行以下命令(确保将 jaxlib 版本映射到 CUDA 版本):
$ pip install --upgrade jax jaxlib==0.1.61+cuda110 -f https://storage.googleapis.com/jax-releases/jax_releases.html
当初将 JAX 与 Numpy 一起导入:
import jax
import jax.numpy as jnp
import numpy as np
JAX 的一些个性
应用 grad() 函数主动微分:这对深度学习利用十分有用,这样就能够很容易地运行反向流传,上面为一个简略的二次函数并在点 1.0 上求导的示例:
from jax import grad
def f(x):
return 3*x**2 + 2*x + 5
def f_prime(x):
return 6*x +2
grad(f)(1.0)
# DeviceArray(8., dtype=float32)
f_prime(1.0)
# 8.0
jit(Just in time):为了利用 XLA 的弱小性能,必须将代码编译到 XLA 内核中。这就是 jit 发挥作用的中央。要应用 XLA 和 jit,用户能够应用 jit() 函数或 @jit 正文。
from jax import jit
x = np.random.rand(1000,1000)
y = jnp.array(x)
def f(x):
for _ in range(10):
x = 0.5*x + 0.1* jnp.sin(x)
return x
g = jit(f)
%timeit -n 5 -r 5 f(y).block_until_ready()
# 5 loops, best of 5: 10.8 ms per loop
%timeit -n 5 -r 5 g(y).block_until_ready()
# 5 loops, best of 5: 341 µs per loop
pmap:主动将计算调配到所有以后设施,并解决它们之间的所有通信。JAX 通过 pmap 转换反对大规模的数据并行,从而将单个处理器无奈解决的大数据进行解决。要查看可用设施,能够运行 jax.devices():
from jax import pmap
def f(x):
return jnp.sin(x) + x**2
f(np.arange(4))
#DeviceArray([0. , 1.841471 , 4.9092975, 9.14112], dtype=float32)
pmap(f)(np.arange(4))
#ShardedDeviceArray([0. , 1.841471 , 4.9092975, 9.14112], dtype=float32)
vmap:是一种函数转换,JAX 通过 vmap 变换提供了主动矢量化算法,大大简化了这种类型的计算,这使得钻研人员在解决新算法时无需再去解决批量化的问题。示例如下:
from jax import vmap
def f(x):
return jnp.square(x)
f(jnp.arange(10))
#DeviceArray([0, 1, 4, 9, 16, 25, 36, 49, 64, 81], dtype=int32)
vmap(f)(jnp.arange(10))
#DeviceArray([0, 1, 4, 9, 16, 25, 36, 49, 64, 81], dtype=int32)
TensorFlow vs PyTorch vs Jax
在深度学习畛域有几家巨头公司,他们所提出的框架被宽广研究者应用。比方谷歌的 TensorFlow、Facebook 的 PyTorch、微软的 CNTK、亚马逊 AWS 的 MXnet 等。
每种框架都有其优缺点,抉择的时候须要依据本身需要进行抉择。
咱们以 Python 中的 3 个次要深度学习框架——TensorFlow、PyTorch 和 Jax 为例进行比拟。这些框架尽管不同,但有两个共同点:
- 它们是开源的。这意味着如果库中存在谬误,使用者能够在 GitHub 中公布问题(并修复),此外你也能够在库中增加本人的性能;
- 因为全局解释器锁,Python 在外部运行迟缓。所以这些框架应用 C/C++ 作为后端来解决所有的计算和并行过程。
那么它们的不同体现在哪些方面呢?如下表所示,为 TensorFlow、PyTorch、JAX 三个框架的比拟。
TensorFlow
TensorFlow 由谷歌开发,最后版本可追溯到 2015 年开源的 TensorFlow0.1,之后倒退稳固,领有弱小的用户群体,成为最受欢迎的深度学习框架。然而用户在应用时,也裸露了 TensorFlow 毛病,例如 API 稳定性有余、动态计算图编程简单等缺点。因而在 TensorFlow2.0 版本,谷歌将 Keras 纳入进来,成为 tf.keras。
目前 TensorFlow 次要特点包含以下:
- 这是一个十分敌对的框架,高级 API-Keras 的可用性使得模型层定义、损失函数和模型创立变得非常容易;
- TensorFlow2.0 带有 Eager Execution(动态图机制),这使得该库更加用户敌对,并且是对以前版本的重大降级;
- Keras 这种高级接口有肯定的毛病,因为 TensorFlow 形象了许多底层机制(只是为了不便最终用户),这让钻研人员在解决模型方面的自由度更小;
- Tensorflow 提供了 TensorBoard,它实际上是 Tensorflow 可视化工具包。它容许研究者可视化损失函数、模型图、模型剖析等。
PyTorch
PyTorch(Python-Torch) 是来自 Facebook 的机器学习库。用 TensorFlow 还是 PyTorch?在一年前,这个问题毫无争议,研究者大部分会抉择 TensorFlow。但当初的状况大不一样了,应用 PyTorch 的研究者越来越多。PyTorch 的一些最重要的个性包含:
- 与 TensorFlow 不同,PyTorch 应用动静类型图,这意味着执行图是在运行中创立的。它容许咱们随时批改和查看图的内部结构;
- 除了用户敌对的高级 API 之外,PyTorch 还包含精心构建的低级 API,容许对机器学习模型进行越来越多的管制。咱们能够在训练期间对模型的前向和后向传递进行检查和批改输入。这被证实对于梯度裁剪和神经格调迁徙十分无效;
- PyTorch 容许用户扩大代码,能够轻松增加新的损失函数和用户定义的层。PyTorch 的 Autograd 模块实现了深度学习算法中的反向流传求导数,在 Tensor 类上的所有操作,Autograd 都能主动提供微分,简化了手动计算导数的简单过程;
- PyTorch 对数据并行和 GPU 的应用具备宽泛的反对;
- PyTorch 比 TensorFlow 更 Python 化。PyTorch 非常适合 Python 生态系统,它容许应用 Python 类调试器工具来调试 PyTorch 代码。
JAX
JAX 是来自 Google 的一个绝对较新的机器学习库。它更像是一个 autograd 库,能够辨别原生的 python 和 NumPy 代码。JAX 的一些个性次要包含:
- 正如官方网站所形容的那样,JAX 可能执行 Python+NumPy 程序的可组合转换:向量化、JIT 到 GPU/TPU 等等;
- 与 PyTorch 相比,JAX 最重要的方面是如何计算梯度。在 Torch 中,图是在前向传递期间创立的,梯度在后向传递期间计算,另一方面,在 JAX 中,计算示意为函数。在函数上应用 grad() 返回一个梯度函数,该函数间接计算给定输出的函数梯度;
- JAX 是一个 autograd 工具,不倡议独自应用。有各种基于 JAX 的机器学习库,其中值得注意的是 ObJax、Flax 和 Elegy。因为它们都应用雷同的外围并且接口只是 JAX 库的 wrapper,因而能够将它们放在同一个 bracket 下;
- Flax 最后是在 PyTorch 生态系统下开发的,更重视应用的灵活性。另一方面,Elegy 受 Keras 启发。ObJAX 次要是为以钻研为导向的目标而设计的,它更重视简略性和可了解性。
参考链接:
- https://www.askpython.com/pyt…
- https://jax.readthedocs.io/en…
- https://jax.readthedocs.io/en…
-
https://www.zhihu.com/questio…
开源前哨
日常分享热门、乏味和实用的开源我的项目。参加保护 10 万 + Star 的开源技术资源库,包含:Python、Java、C/C++、Go、JS、CSS、Node.js、PHP、.NET 等。