关于深度学习:对比PyTorchTensorFlowJAXTheano我发现都在关注两大问题

5次阅读

共计 7250 个字符,预计需要花费 19 分钟才能阅读完成。

作者|王益
OneFlow 社区编译
翻译|杨婷

最近,我在解决 PyTorch 分布式和 TorchRec 相干的工作,为此,我开始学习 PyTorch 2.0。在业余时间,我也在跟着 Alpa 作者学习 JAX 和 XLA。现在回顾这些技术,我发现它们的关注点仿佛都是如下两个问题:

蕴含主动求导和并行在内的函数转换,例如 vmap, pmap 和 pjit 等;
异构计算,CPU 负责控制流,GPU/TPU 负责张量计算和汇合通信。

本文档中的所有例子都反对在 Colab 中运行:

1

函数转换

“函数转换”意为将一个程序转变成另一个程序,最常见的例子是主动求导(autograd)。主动求导采纳用户编写的前向过程并创立后向过程,对于用户来说,编写主动求导通常都太过简单。函数转换的次要难点在于:在编写函数转换算法时以何种形式示意输出和输入过程。

Theano:显式地构建 IR

Theano 是最早的深度学习工具之一,也就是现在为人们所熟知的 Aesara 我的项目。Theano 有一个容许用户在内存中将 IR 构建为数据结构的 API,因而 Theano 可实现主动求导,并将后果输入为 Python 函数。

import aesara
from aesara import tensor as at

a = at.dscalar("a") # Define placeholders, which have no values.
b = at.dscalar("b")

c = a * b              # c now contains the IR of an expression.TT
dc = aesara.grad(c, a) # Convert the IR in c into another one, dc

f_dc = aesara.function([a, b], dc) # Convert the IR into a Python function,
assert f_dc(1.5, 2.5) == 2.5       # so we can call it.

TensorFlow 1.x:用于运行 IR 的虚拟机

TensorFlow 1.x 明确保留了构建 IR 的想法。若在 TensorFlow 中运行上述示例,后果不会有什么差异;但假使在 TensorFlow 1.x 中来运行,最大的差异在于:咱们不会将后向 IR 转换为 Python 函数,并应用 Python 解释器来运行。相同,咱们会在 TensorFlow runtime 中来运行。

import tensorflow.compat.v1 as tf # TensorFlow 1.x API
import numpy as np
tf.disable_eager_execution()

a = tf.placeholder(tf.float32, shape=())
b = tf.placeholder(tf.float32, shape=())

c = a * b
dc = tf.gradients(c, [a], stop_gradients=[a, b])

with tf.compat.v1.Session() as sess:  # TensorFlow has a runtime to execute the IR, 
  x = np.single(2)                    # so, no converting it into Python code. 
  y = np.single(3)     
  print(sess.run(dc, feed_dict={a:x, b:y}))

PyTorch 1.x:没有前向 IR

PyTorch 不会像 Theano 或 TensorFlow 那样将前向流传转换为 IR。反之,PyTorch 应用 Python 解释器来运行前向流传。这样做的弊病在于会在运行期间生成示意后向流传的 IR,咱们称之为 Eager 模式(动态图模式)。

import torch

a = torch.tensor(1.0, requires_grad=True) # These are not placeholders, but values.
b = torch.tensor(2.0)

c = a * b    # Evaluates c and derives the IR of the backward in c.grad_fn_.
c.backward() # Executes c.grad_fn_.
print(c.grad)

TensorFlow 2.x: 梯度带

TensorFlow 2.x 减少了一个像 PyTorch API 的 Eager 模式 API。此 API 追踪前向流传如何运行名为梯度带(GradientTape)的 IR。TensorFlow 2.x 能够从这个跟踪中找出后向流传。

import tensorflow as tf

a = tf.Variable(1.0) # Like PyTorch, these are values, not placehodlers. 
b = tf.Variable(2.0)

with tf.GradientTape() as tape:
  c = a * b
dcda = tape.gradient(c, a)
print(dcda)

JAX

JAX 不会向用户公开诸如梯度带等方面的低级别细节。简略说来,JAX 的思维形式为:将输出和输入都用 Python 函数来示意。

import jax 

a = 2.0
b = 3.0
jax.grad(jax.lax.mul)(a, b)  # Compute c = a * b w.r.t. a.  The result is b=3. 

jax.jit(jax.grad(jax.lax.mul))(a,b)

jax.experimental.pjit(jax.grad(jax.lax.mul), 
                      device_mesh(ntpus))(a,b)

对于想要本人编写的函数转换的高级用户,他们能够调用 make_jaxpr 等低级 API 来拜访 IR,称为 JAXPR。

jax.make_jaxpr(jax.lax.mul)(2.0, 3.0)  # Returns the IR representing jax.lax.mul(2,3)
jax.make_jaxpr(jax.grad(jax.lax.mul))(2.0, 3.0)  # Returns the IR of grad(mul)(2,3)

FuncTorch

FuncTorch 和 JAX 相似,都是基于 PyTorch 的函数转换。

import torch, functorch

a = torch.tensor([2.0])
b = torch.tensor([3.0])
functorch.grad(torch.dot)(a, b)

JAX 的 make_jaxpr 相似于 functorch 的make_fx

def f(a, b):
  return torch.dot(a, b) # Have to wrap the builtin function dot into f. # 必须将内置函数 dot 转换成 f.
  
print(functorch.make_fx(f)(a, b).code)
print(functorch.make_fx(functorch.grad(f))(a, b).code)

TensorFlow 2.x、JAX 和 functorch 都为前向传递构建了一个 IR,但 PyTorch Eager 模式没有。IR 不仅可用于主动求导,还可用于其余类型的函数转换。在下列例子中,functorch.compile.aot_function调用了回调函数 print_compile_fn 两次,别离用于前向和后向流传。

from functorch.compile import aot_function
import torch.fx as fx

def print_compile_fn(fx_module, args):
    print(fx_module)
    return fx_module
aot_fn = aot_function(torch.dot, print_compile_fn)
aot_fn(a, b)

2

高阶导数

PyTorch

import torch
from torch import autograd

x = torch.tensor(1., requires_grad = True)
y = 2*x**3 + 8

first_derivative = autograd.grad(y, x, create_graph=True)
print(first_derivative)

second_derivative = autograd.grad(first_derivative, x)
print(second_derivative)

TensorFlow 2.x

import tensorflow as tf

x = tf.Variable(1.0)

with tf.GradientTape() as outer_tape:
    with tf.GradientTape() as tape:
        y = 2*x**3 + 8
        dy_dx = tape.gradient(y, x)
        print(dy_dx)
    d2y_dx2 = outer_tape.gradient(dy_dx, x)
    print(d2y_dx2)

JAX

def f(a):
  return 2*a**3 + 8

print(jax.grad(f)(1.0))
print(jax.grad(jax.grad(f))(1.0))

3

动态控制流

动态控制流(dynamic control flows)有两个层级:在 CPU 上运行的粗粒度级别和在 GPU /TPU 上运行的细粒度级别。本局部次要介绍在 CPU 上运行的粗粒度级别的动态控制流。上面咱们将用 (if/else) 条件语句作为例子测验深度学习工具。

TensorFlow 1.x

在 TensorFlow 1.x 中,咱们须要将条件语句显式构建到 IR 中。此时条件语句是一个非凡的运算符 tf.cond

def f1(): return tf.multiply(a, 17)
def f2(): return tf.add(b, 23)
r = tf.cond(tf.less(a, b), f1, f2)

with tf.compat.v1.Session() as sess:  # TensorFlow has a runtime to execute the IR,
  print(sess.run(r, feed_dict={a:x, b:y}))

TensorFlow 2.x

TensorFlow 2.x 反对应用 tf.condtf.while_loop 显式构建控制流。此外,试验我的项目 google/tangent 中有 AutoGraph 性能,它能够将 Python 控制流转换为 tf.condtf.while_loop。此性能利用了 Python 解释器反对的函数和函数源代码。例如上面的 g 函数调用了 Python 的规范库将源代码解析为 AST,而后调用 SSA 表单来了解控制流。

def g(x, y):
    if tf.reduce_any(x < y):
        return tf.multiply(x, 17)
    return tf.add(y, 23)
    
converted_g = tf.autograph.to_graph(g)

import inspect
print(inspect.getsource(converted_g))

JAX

因为局部 Python 语法很简单,所以通过解析源代码来了解控制流就显得很艰难,这就导致 AutoGraph 常常出错。但如果这种办法很简略,那么 Python 开发者社区也不会在构建 Python 编译器时失败这么屡次了。正是因为有这种挑战的存在,必须要明确地将控制流构建到 IR 中。为此,JAX 提供了 jax.lax.condjax.lax.for_loop函数。

jax.lax.cond(a < b, lambda : a*17, lambda: b+23)

思考到这一点,你可能会感觉咱们能够应用递归算法。然而上面用于计算阶乘的递归无奈用 JAX 跟踪。

def factorial(r, x):
  return jax.lax.cond(x <= 1.0, lambda: r, lambda: factorial(r*x, x-1))
factorial(1.0, 3.0)

可能你还想调用 factorial 来计算 3!=6。但这会让递归深度超过最大值,因为递归不仅依赖于条件,还依赖于函数定义和调用。

PyTorch

PyTorch 最后是 Python-native。正如前文所说,因为多功能调度机制,gradvamp 的函数转换都是即时的。值得注意的是:

  1. 相比 Theano 和 TensorFlow 构建 IR 后的函数转换,即时函数转换效率更高。
  2. 在进行 gradvmap 时,JAX 也是即时函数转换。然而像 pamppjit等更简单的函数转换须要对整个计算过程进行概述,在这个过程中 IR 是必不可少的。

因为 IR 在 pmappjit 中的必要性,PyTorch 社区最近增加了torch.condpytorch/pytorch#83154

4

分布式计算

依据执行代码或 IR 的不同形式,在应用 Python 解释器或 runtime 时,有两种分布式计算办法。

Python-Native

Theano 和 PyTorch 采纳了 Python-native 分布式计算形式。这种分布式训练工作蕴含多个 Python 解释器过程。这导致呈现了以下后果。

  1. 打包和运行(Pack and run)。因为这些 Python 过程在不同的 host 上运行,因而咱们须要打包用户程序和依赖项,并将它们发送到这些 host 下来运行。始终以来 TorchX 负责了这个打包过程。它反对例如 Docker 和 torch.package 等各种打包格局,并且能够与各种集群管理器配合应用,如 Kubernetes 和 SLURM。
  2. 单程序多数据(SPMD)。因为将用户程序发送到各种 host 上要依赖于打包,与其余权重较轻的形式(如通过 RPC 发送代码)相比,这种形式不太灵便,因而,咱们通常只发送一个程序。当所有这些过程运行同一程序时,这个作业就变成了单程序多数据(SPMD)作业。

Python-native SPMD

上面是一个简略的 SPMD PyTorch 程序,咱们能够在雷同或不同的 host 上应用过程运行这个程序。在这个过程中,咱们只须要调用all_gather。真正的分布式训练程序会调用更高级别的 API,例如torch.nn.parallel.DistributedDataParalleltorchrec.DistributedModelParallel, 而后再调用低级 API,例如 all_gatherall_reduce

import os
import torch
from torch import distributed as dist

def main():
    use_gpu = torch.cuda.is_available()
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    local_world_size = int(os.environ.get("LOCAL_WORLD_SIZE", "0"))
    device = torch.device(f"cuda:{local_rank}" if use_gpu else "cpu")
    dist.init_distributed(backend="nccl")

    lst = torch.tensor([local_rank + 100]).to(device)

    # placeholder 
    rlt_lst = [torch.zeros_like(lst) for _ in range(local_world_size)]
    dist.all_gather(rlt_lst, lst, async_op=False)
    print("After broadcasting:", rlt_lst)

Python-native Non-SPMD

PyTorch 不仅限于 SPMD 式的分布式训练。它还通过 torch.distributed.pipeline.sync.Pipe 和 PiPPy project 提供流水并行,其中流水并行的各个阶段在不同的设施上运行不同的程序。这些阶段常通过 torch.rpc 包来沟通。

分布式运行时机制

分布式 TensorFlow 作业由运行 TensorFlow runtime 程序的过程组成,而不是由 Python 解释器组成。此分布式运行时作业执行 TensorFlow graph (IR),它是由执行用户程序的 Python 解释器生成。

用户程序能够应用低级 API(如 tf.device)去指定作业要运行什么操作、在哪台设施和主机上运行等等。因为 API 有 runtime,所以能够做到这一点。

with tf.device('/job:bar/task:0/device:gpu:2'):
    # ops created here have the fully specified device above

与 PyTorch 一样,TensorFlow 也为分布式训练提供了高级 API tf.distributed.strategy,Keras 和 DTensor。

strategy = tf.distribute.MirroredStrategy() \
           if tf.config.list_physical_devices('GPU') \
           else tf.distribute.get_strategy()

with strategy.scope():
  model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(1,))])

model.compile(loss='mse', optimizer='sgd')

分布式运行时极大中央便了训练服务的保护,因为咱们不再将用户程序打包到集群上运行。相同,咱们打包运行时程序,因为相比用户程序,运行时程序更加对立。

混合理念

JAX 反对 Python-native 和分布式运行时。

JAX 提供例如 vmappmappjit 的函数转换,这能够将 Python 函数转换为分布式程序。

(本文经受权后由 OneFlow 社区编译,译文转载请分割取得受权。原文:https://quip.com/Y8qtAyV4EXRg)

欢送 Star、试用 OneFlow 最新版本:https://github.com/Oneflow-In…

正文完
 0