JAX 是一个由 Google 开发的用于优化科学计算 Python 库:
- 它能够被视为 GPU 和 TPU 上运行的 NumPy,jax.numpy 提供了与 numpy 十分类似 API 接口。
- 它与 NumPy API 十分类似,简直任何能够用 numpy 实现的事件都能够用 jax.numpy 实现。
- 因为应用 XLA(一种减速线性代数计算的编译器)将 Python 和 JAX 代码 JIT 编译成优化的内核,能够在不同设施 (例如 gpu 和 tpu) 上运行。而优化的内核是为高吞吐量设施 (例如 gpu 和 tpu) 进行编译,它与主程序拆散但能够被主程序调用。JIT 编译能够用 jax.jit()触发。
- 它对主动微分有很好的反对,对机器学习钻研很有用。能够应用 jax.grad() 触发主动辨别。
- JAX 激励函数式编程,因为它是面向函数的。与 NumPy 数组不同,JAX 数组始终是不可变的。
- JAX 提供了一些在编写数字解决时十分有用的程序转换,例如 JIT . JAX()用于 JIT 编译和减速代码,JIT .grad()用于求导,以及 JIT .vmap()用于主动向量化或批处理。
- JAX 能够进行异步调度。所以须要调用 .block_until_ready() 以确保计算曾经理论产生。
JAX 应用 JIT 编译有两种形式:
- 主动:在执行 JAX 函数的库调用时,默认状况下 JIT 编译会在后盾进行。
- 手动:您能够应用 jax.jit() 手动申请对本人的 Python 函数进行 JIT 编译。
JAX 应用示例
咱们能够应用 pip 装置库。
pip install jax
导入须要的包,这里咱们也持续应用 NumPy,这样能够执行一些基准测试。
import jax
import jax.numpy as jnp
from jax import random
from jax import grad, jit
import numpy as np
key = random.PRNGKey(0)
与 import numpy as np 相似,咱们能够 import jax.numpy as jnp 并将代码中的所有 np 替换为 jnp。如果 NumPy 代码是用函数式编程格调编写的,那么新的 JAX 代码就能够间接应用。然而,如果有可用的 GPU,JAX 则能够间接应用。
JAX 中随机数的生成形式与 NumPy 不同。JAX 须要创立一个 jax.random.PRNGKey。咱们稍后会看到如何应用它。
咱们在 Google Colab 上做一个简略的基准测试,这样咱们就能够轻松拜访 GPU 和 TPU。咱们首先初始化一个蕴含 25M 元素的随机矩阵,而后将其乘以它的转置。应用针对 CPU 优化的 NumPy,矩阵乘法均匀须要 1.61 秒。
# runs on CPU - numpy
size = 5000
x = np.random.normal(size=(size, size)).astype(np.float32)
%timeit np.dot(x, x.T)
# 1 loop, best of 5: 1.61 s per loop
在 CPU 上应用 JAX 执行雷同的操作均匀须要大概 3.49 秒。
# runs on CPU - JAX
size = 5000
x = random.normal(key, (size, size), dtype=jnp.float32)
%timeit jnp.dot(x, x.T).block_until_ready()
# 1 loop, best of 5: 3.49 s per loop
在 CPU 上运行时,JAX 通常比 NumPy 慢,因为 NumPy 已针对 CPU 进行了十分多的优化。然而,当应用加速器时这种状况会发生变化,所以让咱们尝试应用 GPU 进行矩阵乘法。
# runs on GPU
size = 5000
x = random.normal(key, (size, size), dtype=jnp.float32)
%time x_jax = jax.device_put(x) # 1. measure JAX device transfer time
%time jnp.dot(x_jax, x_jax.T).block_until_ready() # 2. measure JAX compilation time
%timeit jnp.dot(x_jax, x_jax.T).block_until_ready() # 3. measure JAX running time
# 1. CPU times: user 102 µs, sys: 42 µs, total: 144 µs
# Wall time: 155 µs
# 2. CPU times: user 1.3 s, sys: 195 ms, total: 1.5 s
# Wall time: 2.16 s
# 3. 10 loops, best of 5: 68.9 ms per loop
从示例中能够看出,要进行偏心的基准比拟,咱们须要应用 JAX 测量不同的步骤:
设施传输工夫:将矩阵传输到 GPU 所通过的工夫。耗时 0.155 毫秒。编译工夫:JIT 编译通过的工夫。耗时 2.16 秒。运行工夫:无效的代码运行工夫。耗时 68.9 毫秒。
在 GPU 上应用 JAX 进行单个矩阵乘法的总耗时约为 2.23 秒,高于 NumPy 的总工夫 1.61 秒。然而对于每个额定的矩阵乘法,JAX 只须要 68.9 毫秒,而 NumPy 须要 1.61 秒,快了 22 倍多!因而,如果屡次执行线性代数运算,那么应用 JAX 是有意义的。
让咱们测试应用 TPU 进行矩阵乘法。
# runs on TPU
size = 5000
x = random.normal(key, (size, size), dtype=jnp.float32)
%time x_jax = jax.device_put(x) # 1. measure JAX device transfer time
%time jnp.dot(x_jax, x_jax.T).block_until_ready() # 2. measure JAX compilation time
%timeit jnp.dot(x_jax, x_jax.T).block_until_ready() # 3. measure JAX running time
# 1. CPU times: user 131 µs, sys: 72 µs, total: 203 µs
# Wall time: 164 µs
# 2. CPU times: user 190 ms, sys: 302 ms, total: 492 ms
# Wall time: 837 ms
# 3. 100 loops, best of 5: 16.5 ms per loop
疏忽设施传输工夫和编译工夫,每个矩阵乘法均匀须要 16.5 毫秒:GPU 相比快了 4 倍,与 CPU 的 NumPy 相比快了 88 倍。须要阐明的是,当乘以不同大小的矩阵时,取得雷同的减速成果也不同:相乘的矩阵越大,GPU 能够优化操作的越多,减速也越大。
为了在 Google Colab 上复制上述基准,须要运行以下代码让 JAX 晓得有可用的 TPU。
import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()
让咱们看看 XLA 编译器。
XLA
XLA 是 JAX(和其余库,例如 TensorFlow,TPU 的 Pytorch)应用的线性代数的编译器,它通过创立自定义优化内核来保障最快的在程序中运行线性代数运算。XLA 最大的益处是能够让咱们在利用中自定义内核,该局部应用线性代数运算,以便它能够进行最多的优化。
XLA 最重要的优化是交融,即能够在同一个内核中进行多个线性代数运算,将两头输入保留到 GPU 寄存器中,而不将它们具体化到内存中。这能够显着减少咱们的“计算强度”,即所做的工作量与负载和存储数量的比例。交融还能够让咱们齐全省略仅在内存中 shuffle 的操作(例如 reshape)。
上面咱们看看如何应用 XLA 和 jax.jit 手动触发 JIT 编译。
应用 jax.jit 进行即时编译
这里有一些新的基准来测试 jax.jit 的性能。咱们定义了两个实现 SELU(Scaled Exponential Linear Unit)的函数:一个应用 NumPy,一个应用 JAX。临时先不思考 jax.jitat
def selu_np(x, alpha=1.67, lmbda=1.05):
return lmbda * np.where(x > 0, x, alpha * np.exp(x) - alpha)
def selu_jax(x, alpha=1.67, lmbda=1.05):
return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)
而后,咱们应用 NumPy 在 1M 个元素的向量上运行它。
# runs on the CPU - numpy
x = np.random.normal(size=(1000000,)).astype(np.float32)
%timeit selu_np(x)
# 100 loops, best of 5: 7.6 ms per loop
均匀须要 7.6 毫秒。当初让咱们在 CPU 上应用 JAX。
# runs on the CPU - JAX
x = random.normal(key, (1000000,))
%time selu_jax(x).block_until_ready() # 1. measure JAX compilation time
%timeit selu_jax(x).block_until_ready() # 2. measure JAX runtime
# 1. CPU times: user 124 ms, sys: 5.01 ms, total: 129 ms
# Wall time: 124 ms
# 2. 100 loops, best of 5: 4.8 ms per loop
当初均匀须要 4.8 毫秒,在这种状况下比 NumPy 快。下一个测试是在 GPU 上应用 JAX。
# runs on the GPU
x = random.normal(key, (1000000,))
%time x_jax = jax.device_put(x) # 1. measure JAX device transfer time
%time selu_jax(x_jax).block_until_ready() # 2. measure JAX compilation time
%timeit selu_jax(x_jax).block_until_ready() # 3. measure JAX runtime
# 1. CPU times: user 103 µs, sys: 0 ns, total: 103 µs
# Wall time: 109 µs
# 2. CPU times: user 148 ms, sys: 9.09 ms, total: 157 ms
# Wall time: 447 ms
# 3. 1000 loops, best of 5: 1.21 ms per loop
函数运行工夫为 1.21 毫秒。上面咱们用 jax.jit 测试它,触发 JIT 编译器应用 XLA 将 SELU 函数编译到优化的 GPU 内核中,同时优化函数外部的所有操作。
# runs on the GPU
x = random.normal(key, (1000000,))
selu_jax_jit = jit(selu_jax)
%time x_jax = jax.device_put(x) # 1. measure JAX device transfer time
%time selu_jax_jit(x_jax).block_until_ready() # 2. measure JAX compilation time
%timeit selu_jax_jit(x_jax).block_until_ready() # 3. measure JAX runtime
# 1. CPU times: user 70 µs, sys: 28 µs, total: 98 µs
# Wall time: 104 µs
# 2. CPU times: user 66.6 ms, sys: 1.18 ms, total: 67.8 ms
# Wall time: 122 ms
# 3. 10000 loops, best of 5: 130 µs per loop
应用编译内核,函数运行工夫为 0.13 毫秒!
让咱们回顾一下不同的运行工夫:
- CPU 上的 NumPy:7.6 毫秒。
- CPU 上的 JAX:4.8 毫秒(x1.58 减速)。
- 没有 JIT 的 GPU 上的 JAX:1.21 毫秒(x6.28 减速)。
- 带有 JIT 的 GPU 上的 JAX:0.13 毫秒(x58.46 减速)。
应用 JIT 编译防止从 GPU 寄存器中挪动数据这样给咱们带来了十分大的减速。一般来说在不同类型的内存之间挪动数据与代码执行相比十分慢,因而在理论应用时应该尽量避免!
将 SELU 函数利用于不同大小的向量时,您可能会取得不同的后果。矢量越大,加速器越能优化操作,减速也越大。
除了执行 selu_jax_jit = jit(selu_jax) 之外,还能够应用 @jit 装璜器对函数进行 JIT 编译,如下所示。
@jit
def selu_jax_jit(x, alpha=1.67, lmbda=1.05):
return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)
JIT 编译能够减速,为什么咱们不能全副都这样做呢?因为并非所有代码都能够 JIT 编译,JIT 要求数组形态是动态的并且在编译时已知。另外就是引入 jax.jit 也会带来一些开销。因而通常只有编译的函数比较复杂并且须要屡次运行能力节省时间。然而这在机器学习中很常见,例如咱们倾编译一个大而简单的模型,而后运行它进行数百万次训练、损失函数和指标的计算。
应用 jax.grad 主动微分
另一个 JAX 转换是应用 jit.grad() 函数的主动微分。
借助 Autograd,JAX 能够主动对原生 Python 和 NumPy 代码进行微分。并且反对 Python 的大部分个性,包含循环、if、递归和闭包。
上面看看一个带有 jit.grad() 的代码示例,咱们计算一个自定义的蕴含 JAX 函数的 Python 函数的导数
def sum_logistic(x):
return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))
x_small = jnp.arange(3.)
derivative_fn = grad(sum_logistic)
print(derivative_fn(x_small))
# [0.25, 0.19661197, 0.10499357]
总结
在本文中,咱们理解了 JAX 是什么,并理解了它的一些基本概念:NumPy 接口、JIT 编译、XLA、优化内核、程序转换、主动微分和函数式编程。在 JAX 之上,开源社区为机器学习构建了更多高级库,例如 Flax 和 Haiku。有趣味的能够搜寻查看。
https://avoid.overfit.cn/post/589106b6f0a0431480a42a1bf399e81e
作者:Fabio Chiusano