共计 5763 个字符,预计需要花费 15 分钟才能阅读完成。
1 Graph 概述
计算图 Graph 是 TensorFlow 的核心对象,TensorFlow 的运行流程基本都是围绕它进行的。包括图的构建、传递、剪枝、按 worker 分裂、按设备二次分裂、执行、注销等。因此理解计算图 Graph 对掌握 TensorFlow 运行尤为关键。
2 默认 Graph
默认图替换
之前讲解 Session 的时候就说过,一个 Session 只能 run 一个 Graph,但一个 Graph 可以运行在多个 Session 中。常见情况是,session 会运行全局唯一的隐式的默认的 Graph,operation 也是注册到这个 Graph 中。
也可以显示创建 Graph,并调用 as_default() 使他替换默认 Graph。在该上下文管理器中创建的 op 都会注册到这个 graph 中。退出上下文管理器后,则恢复原来的默认 graph。一般情况下,我们不用显式创建 Graph,使用系统创建的那个默认 Graph 即可。
print tf.get_default_graph()
with tf.Graph().as_default() as g:
print tf.get_default_graph() is g
print tf.get_default_graph()
print tf.get_default_graph()
输出如下
<tensorflow.python.framework.ops.Graph object at 0x106329fd0>
True
<tensorflow.python.framework.ops.Graph object at 0x18205cc0d0>
<tensorflow.python.framework.ops.Graph object at 0x10d025fd0>
由此可见,在上下文管理器中,当前线程的默认图被替换了,而退出上下文管理后,则恢复为了原来的默认图。
默认图管理
默认 graph 和默认 session 一样,也是线程作用域的。当前线程中,永远都有且仅有一个 graph 为默认图。TensorFlow 同样通过栈来管理线程的默认 graph。
@tf_export(“Graph”)
class Graph(object):
# 替换线程默认图
def as_default(self):
return _default_graph_stack.get_controller(self)
# 栈式管理,push pop
@tf_contextlib.contextmanager
def get_controller(self, default):
try:
context.context_stack.push(default.building_function, default.as_default)
finally:
context.context_stack.pop()
替换默认图采用了堆栈的管理方式,通过 push pop 操作进行管理。获取默认图的操作如下,通过默认 graph 栈_default_graph_stack 来获取。
@tf_export(“get_default_graph”)
def get_default_graph():
return _default_graph_stack.get_default()
下面来看_default_graph_stack 的创建
_default_graph_stack = _DefaultGraphStack()
class _DefaultGraphStack(_DefaultStack):
def __init__(self):
# 调用父类来创建
super(_DefaultGraphStack, self).__init__()
self._global_default_graph = None
class _DefaultStack(threading.local):
def __init__(self):
super(_DefaultStack, self).__init__()
self._enforce_nesting = True
# 和默认 session 栈一样,本质上也是一个 list
self.stack = []
_default_graph_stack 的创建如上所示,最终和默认 session 栈一样,本质上也是一个 list。
3 前端 Graph 数据结构
Graph 数据结构
理解一个对象,先从它的数据结构开始。我们先来看 Python 前端中,Graph 的数据结构。Graph 主要的成员变量是 Operation 和 Tensor。Operation 是 Graph 的节点,它代表了运算算子。Tensor 是 Graph 的边,它代表了运算数据。
@tf_export(“Graph”)
class Graph(object):
def __init__(self):
# 加线程锁,使得注册 op 时,不会有其他线程注册 op 到 graph 中,从而保证共享 graph 是线程安全的
self._lock = threading.Lock()
# op 相关数据。
# 为 graph 的每个 op 分配一个 id,通过 id 可以快速索引到相关 op。故创建了_nodes_by_id 字典
self._nodes_by_id = dict() # GUARDED_BY(self._lock)
self._next_id_counter = 0 # GUARDED_BY(self._lock)
# 同时也可以通过 name 来快速索引 op,故创建了_nodes_by_name 字典
self._nodes_by_name = dict() # GUARDED_BY(self._lock)
self._version = 0 # GUARDED_BY(self._lock)
# tensor 相关数据。
# 处理 tensor 的 placeholder
self._handle_feeders = {}
# 处理 tensor 的 read 操作
self._handle_readers = {}
# 处理 tensor 的 move 操作
self._handle_movers = {}
# 处理 tensor 的 delete 操作
self._handle_deleters = {}
下面看 graph 如何添加 op 的,以及保证线程安全的。
def _add_op(self, op):
# graph 被设置为 final 后,就是只读的了,不能添加 op 了。
self._check_not_finalized()
# 保证共享 graph 的线程安全
with self._lock:
# 将 op 以 id 和 name 分别构建字典,添加到_nodes_by_id 和_nodes_by_name 字典中,方便后续快速索引
self._nodes_by_id[op._id] = op
self._nodes_by_name[op.name] = op
self._version = max(self._version, op._id)
GraphKeys 图分组
每个 Operation 节点都有一个特定的标签,从而实现节点的分类。相同标签的节点归为一类,放到同一个 Collection 中。标签是一个唯一的 GraphKey,GraphKey 被定义在类 GraphKeys 中,如下
@tf_export(“GraphKeys”)
class GraphKeys(object):
GLOBAL_VARIABLES = “variables”
QUEUE_RUNNERS = “queue_runners”
SAVERS = “savers”
WEIGHTS = “weights”
BIASES = “biases”
ACTIVATIONS = “activations”
UPDATE_OPS = “update_ops”
LOSSES = “losses”
TRAIN_OP = “train_op”
# 省略其他
name_scope 节点命名空间
使用 name_scope 对 graph 中的节点进行层次化管理,上下层之间通过斜杠分隔。
# graph 节点命名空间
g = tf.get_default_graph()
with g.name_scope(“scope1”):
c = tf.constant(“hello, world”, name=”c”)
print c.op.name
with g.name_scope(“scope2”):
c = tf.constant(“hello, world”, name=”c”)
print c.op.name
输出如下
scope1/c
scope1/scope2/c # 内层的 scope 会继承外层的,类似于栈,形成层次化管理
4 后端 Graph 数据结构
Graph
先来看 graph.h 文件中的 Graph 类的定义,只看关键代码
class Graph {
private:
// 所有已知的 op 计算函数的注册表
FunctionLibraryDefinition ops_;
// GraphDef 版本号
const std::unique_ptr<VersionDef> versions_;
// 节点 node 列表,通过 id 来访问
std::vector<Node*> nodes_;
// node 个数
int64 num_nodes_ = 0;
// 边 edge 列表,通过 id 来访问
std::vector<Edge*> edges_;
// graph 中非空 edge 的数目
int num_edges_ = 0;
// 已分配了内存,但还没使用的 node 和 edge
std::vector<Node*> free_nodes_;
std::vector<Edge*> free_edges_;
}
后端中的 Graph 主要成员也是节点 node 和边 edge。节点 node 为计算算子 Operation,边为算子所需要的数据,或者代表节点间的依赖关系。这一点和 Python 中的定义相似。边 Edge 的持有它的源节点和目标节点的指针,从而将两个节点连接起来。下面看 Edge 类的定义。
Edge
class Edge {
private:
Edge() {}
friend class EdgeSetTest;
friend class Graph;
// 源节点, 边的数据就来源于源节点的计算。源节点是边的生产者
Node* src_;
// 目标节点,边的数据提供给目标节点进行计算。目标节点是边的消费者
Node* dst_;
// 边 id,也就是边的标识符
int id_;
// 表示当前边为源节点的第 src_output_条边。源节点可能会有多条输出边
int src_output_;
// 表示当前边为目标节点的第 dst_input_条边。目标节点可能会有多条输入边。
int dst_input_;
};
Edge 既可以承载 tensor 数据,提供给节点 Operation 进行运算,也可以用来表示节点之间有依赖关系。对于表示节点依赖的边,其 src_output_, dst_input_均为 -1,此时边不承载任何数据。
下面来看 Node 类的定义。
Node
class Node {
public:
// NodeDef, 节点算子 Operation 的信息,比如 op 分配到哪个设备上了,op 的名字等,运行时有可能变化。
const NodeDef& def() const;
// OpDef, 节点算子 Operation 的元数据,不会变的。比如 Operation 的入参列表,出参列表等
const OpDef& op_def() const;
private:
// 输入边,传递数据给节点。可能有多条
EdgeSet in_edges_;
// 输出边,节点计算后得到的数据。可能有多条
EdgeSet out_edges_;
}
节点 Node 中包含的主要数据有输入边和输出边的集合,从而能够由 Node 找到跟他关联的所有边。Node 中还包含 NodeDef 和 OpDef 两个成员。NodeDef 表示节点算子的信息,运行时可能会变,创建 Node 时会 new 一个 NodeDef 对象。OpDef 表示节点算子的元信息,运行时不会变,创建 Node 时不需要 new OpDef,只需要从 OpDef 仓库中取出即可。因为元信息是确定的,比如 Operation 的入参个数等。
由 Node 和 Edge,即可以组成图 Graph,通过任何节点和任何边,都可以遍历完整图。Graph 执行计算时,按照拓扑结构,依次执行每个 Node 的 op 计算,最终即可得到输出结果。入度为 0 的节点,也就是依赖数据已经准备好的节点,可以并发执行,从而提高运行效率。
系统中存在默认的 Graph,初始化 Graph 时,会添加一个 Source 节点和 Sink 节点。Source 表示 Graph 的起始节点,Sink 为终止节点。Source 的 id 为 0,Sink 的 id 为 1,其他节点 id 均大于 1.
5 Graph 运行时生命周期
Graph 是 TensorFlow 的核心对象,TensorFlow 的运行均是围绕 Graph 进行的。运行时 Graph 大致经过了以下阶段
图构建:client 端用户将创建的节点注册到 Graph 中,一般不需要显示创建 Graph,使用系统创建的默认的即可。
图发送:client 通过 session.run() 执行运行时,将构建好的整图序列化为 GraphDef 后,传递给 master
图剪枝:master 先反序列化拿到 Graph,然后根据 session.run() 传递的 fetches 和 feeds 列表,反向遍历全图 full graph,实施剪枝,得到最小依赖子图。
图分裂:master 将最小子图分裂为多个 Graph Partition,并注册到多个 worker 上。一个 worker 对应一个 Graph Partition。
图二次分裂:worker 根据当前可用硬件资源,如 CPU GPU,将 Graph Partition 按照 op 算子设备约束规范(例如 tf.device(’/cpu:0’),二次分裂到不同设备上。每个计算设备对应一个 Graph Partition。
图运行:对于每一个计算设备,worker 依照 op 在 kernel 中的实现,完成 op 的运算。设备间数据通信可以使用 send/recv 节点,而 worker 间通信,则使用 GRPC 或 RDMA 协议。
这些阶段根据 TensorFlow 运行时的不同,会进行不同的处理。运行时有两种,本地运行时和分布式运行时。故 Graph 生命周期到后面分析本地运行时和分布式运行时的时候,再详细讲解。
本文作者:扬易阅读原文
本文为云栖社区原创内容,未经允许不得转载。