共计 14102 个字符,预计需要花费 36 分钟才能阅读完成。
1 Session 概述
Session 是 TensorFlow 前后端连接的桥梁。用户利用 session 使得 client 能够与 master 的执行引擎建立连接,并通过 session.run() 来触发一次计算。它建立了一套上下文环境,封装了 operation 计算以及 tensor 求值的环境。
session 创建时,系统会分配一些资源,比如 graph 引用、要连接的计算引擎的名称等。故计算完毕后,需要使用 session.close() 关闭 session,避免引起内存泄漏,特别是 graph 无法释放的问题。可以显式调用 session.close(), 或利用 with 上下文管理器,或者直接使用 InteractiveSession。
session 之间采用共享 graph 的方式来提高运行效率。一个 session 只能运行一个 graph 实例,但一个 graph 可以运行在多个 session 中。一般情况下,创建 session 时如果不指定 Graph 实例,则会使用系统默认 Graph。常见情况下,我们都是使用一个 graph,即默认 graph。当 session 创建时,不会重新创建 graph 实例,而是默认 graph 引用计数加 1。当 session close 时,引用计数减 1。只有引用计数为 0 时,graph 才会被回收。这种 graph 共享的方式,大大减少了 graph 创建和回收的资源消耗,优化了 TensorFlow 运行效率。
2 默认 session
op 运算和 tensor 求值时,如果没有指定运行在哪个 session 中,则会运行在默认 session 中。通过 session.as_default() 可以将自己设置为默认 session。但个人建议最好还是通过 session.run(operator) 和 session.run(tensor) 来进行 op 运算和 tensor 求值。
operation.run()
operation.run() 等价于 tf.get_default_session().run(operation)
@tf_export(“Operation”)
class Operation(object):
# 通过 operation.run() 调用,进行 operation 计算
def run(self, feed_dict=None, session=None):
_run_using_default_session(self, feed_dict, self.graph, session)
def _run_using_default_session(operation, feed_dict, graph, session=None):
# 没有指定 session,则获取默认 session
if session is None:
session = get_default_session()
# 最终还是通过 session.run() 进行运行的。tf 中任何运算,都是通过 session 来 run 的。
# 通过 session 来建立 client 和 master 的连接,并将 graph 发送给 master,master 再进行执行
session.run(operation, feed_dict)
tensor.eval()
tensor.eval() 等价于 tf.get_default_session().run(tensor), 如下
@tf_export(“Tensor”)
class Tensor(_TensorLike):
# 通过 tensor.eval() 调用,进行 tensor 运算
def eval(self, feed_dict=None, session=None):
return _eval_using_default_session(self, feed_dict, self.graph, session)
def _eval_using_default_session(tensors, feed_dict, graph, session=None):
# 如果没有指定 session,则获取默认 session
if session is None:
session = get_default_session()
return session.run(tensors, feed_dict)
默认 session 的管理
tf 通过运行时维护的 session 本地线程栈,来管理默认 session。故不同的线程会有不同的默认 session,默认 session 是线程作用域的。
# session 栈
_default_session_stack = _DefaultStack()
# 获取默认 session 的接口
@tf_export(“get_default_session”)
def get_default_session():
return _default_session_stack.get_default()
# _DefaultStack 默认 session 栈是线程相关的
class _DefaultStack(threading.local):
# 默认 session 栈的创建,其实就是一个 list
def __init__(self):
super(_DefaultStack, self).__init__()
self._enforce_nesting = True
self.stack = []
# 获取默认 session
def get_default(self):
return self.stack[-1] if len(self.stack) >= 1 else None
3 前端 Session 类型
session 类图
会话 Session 的 UML 类图如下
分为两种类型,普通 Session 和交互式 InteractiveSession。InteractiveSession 和 Session 基本相同,区别在于
InteractiveSession 创建后,会将自己替换为默认 session。使得之后 operation.run() 和 tensor.eval() 的执行通过这个默认 session 来进行。特别适合 Python 交互式环境。
InteractiveSession 自带 with 上下文管理器。它在创建时和关闭时会调用上下文管理器的 enter 和 exit 方法,从而进行资源的申请和释放,避免内存泄漏问题。这同样很适合 Python 交互式环境。
Session 和 InteractiveSession 的代码逻辑不多,主要逻辑均在其父类 BaseSession 中。主要代码如下
@tf_export(‘Session’)
class Session(BaseSession):
def __init__(self, target=”, graph=None, config=None):
# session 创建的主要逻辑都在其父类 BaseSession 中
super(Session, self).__init__(target, graph, config=config)
self._default_graph_context_manager = None
self._default_session_context_manager = None
@tf_export(‘InteractiveSession’)
class InteractiveSession(BaseSession):
def __init__(self, target=”, graph=None, config=None):
self._explicitly_closed = False
# 将自己设置为 default session
self._default_session = self.as_default()
self._default_session.enforce_nesting = False
# 自动调用上下文管理器的__enter__() 方法
self._default_session.__enter__()
self._explicit_graph = graph
def close(self):
super(InteractiveSession, self).close()
## 省略无关代码
## 自动调用上下文管理器的__exit__() 方法,避免内存泄漏
self._default_session.__exit__(None, None, None)
self._default_session = None
BaseSession
BaseSession 基本包含了所有的会话实现逻辑。包括会话的整个生命周期,也就是创建 执行 关闭和销毁四个阶段。生命周期后面详细分析。BaseSession 包含的主要成员变量有 graph 引用,序列化的 graph_def, 要连接的 tf 引擎 target,session 配置信息 config 等。
4 后端 Session 类型
在后端 master 中,根据前端 client 调用 tf.Session(target=”, graph=None, config=None) 时指定的 target,来创建不同的 Session。target 为要连接的 tf 后端执行引擎,默认为空字符串。Session 创建采用了抽象工厂模式,如果为空字符串,则创建本地 DirectSession,如果以 grpc:// 开头,则创建分布式 GrpcSession。类图如下
DirectSession 只能利用本地设备,将任务创建到本地的 CPU GPU 上。而 GrpcSession 则可以利用远端分布式设备,将任务创建到其他机器的 CPU GPU 上,然后通过 grpc 协议进行通信。grpc 协议是谷歌发明并开源的远程通信协议。
5 Session 生命周期
Session 作为前后端连接的桥梁,以及上下文运行环境,其生命周期尤其关键。大致分为 4 个阶段
创建:通过 tf.Session() 创建 session 实例,进行系统资源分配,特别是 graph 引用计数加 1
运行:通过 session.run() 触发计算的执行,client 会将整图 graph 传递给 master,由 master 进行执行
关闭:通过 session.close() 来关闭,会进行系统资源的回收,特别是 graph 引用计数减 1.
销毁:Python 垃圾回收器进行 GC 时,调用 session.__del__() 进行回收。
生命周期方法入口基本都在前端 Python 的 BaseSession 中,它会通过 swig 自动生成的函数符号映射关系,调用 C 层的实现。
5.1 创建
先从 BaseSession 类的 init 方法看起,只保留了主要代码。
def __init__(self, target=”, graph=None, config=None):
# graph 表示构建的图。TensorFlow 的一个 session 会对应一个图。这个图包含了所有涉及到的算子
# graph 如果没有设置(通常都不会设置),则使用默认 graph
if graph is None:
self._graph = ops.get_default_graph()
else:
self._graph = graph
self._opened = False
self._closed = False
self._current_version = 0
self._extend_lock = threading.Lock()
# target 为要连接的 tf 执行引擎
if target is not None:
self._target = compat.as_bytes(target)
else:
self._target = None
self._delete_lock = threading.Lock()
self._dead_handles = []
# config 为 session 的配置信息
if config is not None:
self._config = config
self._add_shapes = config.graph_options.infer_shapes
else:
self._config = None
self._add_shapes = False
self._created_with_new_api = ops._USE_C_API
# 调用 C 层来创建 session
self._session = None
opts = tf_session.TF_NewSessionOptions(target=self._target, config=config)
self._session = tf_session.TF_NewSession(self._graph._c_graph, opts, status)
BaseSession 先进行成员变量的赋值,然后调用 TF_NewSession 来创建 session。TF_NewSession() 方法由 swig 自动生成,在 bazel-bin/tensorflow/python/pywrap_tensorflow_internal.py 中
def TF_NewSession(graph, opts, status):
return _pywrap_tensorflow_internal.TF_NewSession(graph, opts, status)
_pywrap_tensorflow_internal 包含了 C 层函数的符号表。在 swig 模块 import 时,会加载 pywrap_tensorflow_internal.so 动态链接库,从而得到符号表。在 pywrap_tensorflow_internal.cc 中,注册了供 Python 调用的函数的符号表,从而实现 Python 到 C 的函数映射和调用。
// c++ 函数调用的符号表,Python 通过它可以调用到 C 层代码。符号表和动态链接库由 swig 自动生成
static PyMethodDef SwigMethods[] = {
// .. 省略其他函数定义
// TF_NewSession 的符号表,通过这个映射,Python 中就可以调用到 C 层代码了。
{(char *)”TF_NewSession”, _wrap_TF_NewSession, METH_VARARGS, NULL},
// … 省略其他函数定义
}
最终调用到 c_api.c 中的 TF_NewSession()
// TF_NewSession 创建 session 的新实现,在 C 层后端代码中
TF_Session* TF_NewSession(TF_Graph* graph, const TF_SessionOptions* opt,
TF_Status* status) {
Session* session;
// 创建 session
status->status = NewSession(opt->options, &session);
if (status->status.ok()) {
TF_Session* new_session = new TF_Session(session, graph);
if (graph != nullptr) {
// 采用了引用计数方式,多个 session 共享一个图实例,效率更高。
// session 创建时,引用计数加 1。session close 时引用计数减 1。引用计数为 0 时,graph 才会被回收。
mutex_lock l(graph->mu);
graph->sessions[new_session] = Status::OK();
}
return new_session;
} else {
DCHECK_EQ(nullptr, session);
return nullptr;
}
}
session 创建时,并创建 graph,而是采用共享方式,只是引用计数加 1 了。这种方式减少了 session 创建和关闭时的资源消耗,提高了运行效率。NewSession() 根据前端传递的 target,使用 sessionFactory 创建对应的 TensorFlow::Session 实例。
Status NewSession(const SessionOptions& options, Session** out_session) {
SessionFactory* factory;
const Status s = SessionFactory::GetFactory(options, &factory);
// 通过 sessionFactory 创建多态的 Session。本地 session 为 DirectSession,分布式为 GRPCSession
*out_session = factory->NewSession(options);
if (!*out_session) {
return errors::Internal(“Failed to create session.”);
}
return Status::OK();
}
创建 session 采用了抽象工厂模式。根据 client 传递的 target,来创建不同的 session。如果 target 为空字符串,则创建本地 DirectSession。如果以 grpc:// 开头,则创建分布式 GrpcSession。TensorFlow 包含本地运行时和分布式运行时两种运行模式。
下面来看 DirectSessionFactory 的 NewSession() 方法
class DirectSessionFactory : public SessionFactory {
public:
Session* NewSession(const SessionOptions& options) override {
std::vector<Device*> devices;
// job 在本地执行
const Status s = DeviceFactory::AddDevices(
options, “/job:localhost/replica:0/task:0”, &devices);
if (!s.ok()) {
LOG(ERROR) << s;
return nullptr;
}
DirectSession* session =
new DirectSession(options, new DeviceMgr(devices), this);
{
mutex_lock l(sessions_lock_);
sessions_.push_back(session);
}
return session;
}
GrpcSessionFactory 的 NewSession() 方法就不详细分析了,它会将 job 任务创建在分布式设备上,各 job 通过 grpc 协议通信。
5.2 运行
通过 session.run() 可以启动 graph 的执行。入口在 BaseSession 的 run() 方法中, 同样只列出关键代码
class BaseSession(SessionInterface):
def run(self, fetches, feed_dict=None, options=None, run_metadata=None):
# fetches 可以为单个变量,或者数组,或者元组。它是图的一部分,可以是操作 operation,也可以是数据 tensor,或者他们的名字 String
# feed_dict 为对应 placeholder 的实际训练数据,它的类型为字典
result = self._run(None, fetches, feed_dict, options_ptr,run_metadata_ptr)
return result
def _run(self, handle, fetches, feed_dict, options, run_metadata):
# 创建 fetch 处理器 fetch_handler
fetch_handler = _FetchHandler(
self._graph, fetches, feed_dict_tensor, feed_handles=feed_handles)
# 经过不同类型的 fetch_handler 处理,得到最终的 fetches 和 targets
# targets 为要执行的 operation,fetches 为要执行的 tensor
_ = self._update_with_movers(feed_dict_tensor, feed_map)
final_fetches = fetch_handler.fetches()
final_targets = fetch_handler.targets()
# 开始运行
if final_fetches or final_targets or (handle and feed_dict_tensor):
results = self._do_run(handle, final_targets, final_fetches,
feed_dict_tensor, options, run_metadata)
else:
results = []
# 输出结果到 results 中
return fetch_handler.build_results(self, results)
def _do_run(self, handle, target_list, fetch_list, feed_dict, options, run_metadata):
# 将要运行的 operation 添加到 graph 中
self._extend_graph()
# 执行一次运行 run,会调用底层 C 来实现
return tf_session.TF_SessionPRunSetup_wrapper(
session, feed_list, fetch_list, target_list, status)
# 将要运行的 operation 添加到 graph 中
def _extend_graph(self):
with self._extend_lock:
if self._graph.version > self._current_version:
# 生成 graph_def 对象,它是 graph 的序列化表示
graph_def, self._current_version = self._graph._as_graph_def(
from_version=self._current_version, add_shapes=self._add_shapes)
# 通过 TF_ExtendGraph 将序列化后的 graph,也就是 graph_def 传递给后端
with errors.raise_exception_on_not_ok_status() as status:
tf_session.TF_ExtendGraph(self._session,
graph_def.SerializeToString(), status)
self._opened = True
逻辑还是十分复杂的,主要有一下几步
入参处理,创建 fetch 处理器 fetch_handler,得到最终要执行的 operation 和 tensor
对 graph 进行序列化,生成 graph_def 对象
将序列化后的 grap_def 对象传递给后端 master。
通过后端 master 来 run。
我们分别来看 extend 和 run。
5.2.1 extend 添加节点到 graph 中
TF_ExtendGraph() 会调用到 c_api 中,这个逻辑同样通过 swig 工具自动生成。下面看 c_api.cc 中的 TF_ExtendGraph() 方法
// 增加节点到 graph 中,proto 为序列化后的 graph
void TF_ExtendGraph(TF_DeprecatedSession* s, const void* proto,
size_t proto_len, TF_Status* status) {
GraphDef g;
// 先将 proto 反序列化,得到 client 传递的 graph,放入 g 中
if (!tensorflow::ParseProtoUnlimited(&g, proto, proto_len)) {
status->status = InvalidArgument(“Invalid GraphDef”);
return;
}
// 再调用 session 的 extend 方法。根据创建的不同 session 类型,多态调用不同方法。
status->status = s->session->Extend(g);
}
后端系统根据生成的 Session 类型,多态的调用 Extend 方法。如果是本地 session,则调用 DirectSession 的 Extend() 方法。如果是分布式 session,则调用 GrpcSession 的相关方法。下面来看 GrpcSession 的 Extend 方法。
Status GrpcSession::Extend(const GraphDef& graph) {
CallOptions call_options;
call_options.SetTimeout(options_.config.operation_timeout_in_ms());
return ExtendImpl(&call_options, graph);
}
Status GrpcSession::ExtendImpl(CallOptions* call_options,
const GraphDef& graph) {
bool handle_is_empty;
{
mutex_lock l(mu_);
handle_is_empty = handle_.empty();
}
if (handle_is_empty) {
// 如果 graph 句柄为空,则表明 graph 还没有创建好,此时 extend 就等同于 create
return Create(graph);
}
mutex_lock l(mu_);
ExtendSessionRequest req;
req.set_session_handle(handle_);
*req.mutable_graph_def() = graph;
req.set_current_graph_version(current_graph_version_);
ExtendSessionResponse resp;
// 调用底层实现,来添加节点到 graph 中
Status s = master_->ExtendSession(call_options, &req, &resp);
if (s.ok()) {
current_graph_version_ = resp.new_graph_version();
}
return s;
}
Extend() 方法中要注意的一点是,如果是首次执行 Extend(), 则要先调用 Create() 方法进行 graph 的注册。否则才是执行添加节点到 graph 中。
5.2.2 run 执行图的计算
同样,Python 通过 swig 自动生成的代码,来实现对 C API 的调用。C 层实现在 c_api.cc 的 TF_Run() 中。
// session.run() 的 C 层实现
void TF_Run(TF_DeprecatedSession* s, const TF_Buffer* run_options,
// Input tensors,输入的数据 tensor
const char** c_input_names, TF_Tensor** c_inputs, int ninputs,
// Output tensors,运行计算后输出的数据 tensor
const char** c_output_names, TF_Tensor** c_outputs, int noutputs,
// Target nodes,要运行的节点
const char** c_target_oper_names, int ntargets,
TF_Buffer* run_metadata, TF_Status* status) {
// 省略一段代码
TF_Run_Helper(s->session, nullptr, run_options, input_pairs, output_names,
c_outputs, target_oper_names, run_metadata, status);
}
// 真正的实现了 session.run()
static void TF_Run_Helper() {
RunMetadata run_metadata_proto;
// 调用不同的 session 实现类的 run 方法,来执行
result = session->Run(run_options_proto, input_pairs, output_tensor_names,
target_oper_names, &outputs, &run_metadata_proto);
// 省略代码
}
最终会调用创建的 session 来执行 run 方法。DirectSession 和 GrpcSession 的 Run() 方法会有所不同。后面很复杂,就不接着分析了。
5.3 关闭 session
通过 session.close() 来关闭 session,释放相关资源,防止内存泄漏。
class BaseSession(SessionInterface):
def close(self):
tf_session.TF_CloseSession(self._session, status)
会调用到 C API 的 TF_CloseSession() 方法。
void TF_CloseSession(TF_Session* s, TF_Status* status) {
status->status = s->session->Close();
}
最终根据创建的 session,多态的调用其 Close() 方法。同样分为 DirectSession 和 GrpcSession 两种。
::tensorflow::Status DirectSession::Close() {
cancellation_manager_->StartCancel();
{
mutex_lock l(closed_lock_);
if (closed_) return ::tensorflow::Status::OK();
closed_ = true;
}
// 注销 session
if (factory_ != nullptr) factory_->Deregister(this);
return ::tensorflow::Status::OK();
}
DirectSessionFactory 中的 Deregister() 方法如下
void Deregister(const DirectSession* session) {
mutex_lock l(sessions_lock_);
// 释放相关资源
sessions_.erase(std::remove(sessions_.begin(), sessions_.end(), session),
sessions_.end());
}
5.4 销毁 session
session 的销毁是由 Python 的 GC 自动执行的。python 通过引用计数方法来判断是否回收对象。当对象的引用计数为 0,且虚拟机触发了 GC 时,会调用对象的__del__() 方法来销毁对象。引用计数法有个很致命的问题,就是无法解决循环引用问题,故会存在内存泄漏。Java 虚拟机采用了调用链分析的方式来决定哪些对象会被回收。
class BaseSession(SessionInterface):
def __del__(self):
# 先 close,防止用户没有调用 close()
try:
self.close()
# 再调用 c api 的 TF_DeleteSession 来销毁 session
if self._session is not None:
try:
status = c_api_util.ScopedTFStatus()
if self._created_with_new_api:
tf_session.TF_DeleteSession(self._session, status)
c_api.cc 中的相关逻辑如下
void TF_DeleteSession(TF_Session* s, TF_Status* status) {
status->status = Status::OK();
TF_Graph* const graph = s->graph;
if (graph != nullptr) {
graph->mu.lock();
graph->sessions.erase(s);
// 如果 graph 的引用计数为 0,也就是 graph 没有被任何 session 持有,则考虑销毁 graph 对象
const bool del = graph->delete_requested && graph->sessions.empty();
graph->mu.unlock();
// 销毁 graph 对象
if (del) delete graph;
}
// 销毁 session 和 TF_Session
delete s->session;
delete s;
}
TF_DeleteSession() 会判断 graph 的引用计数是否为 0,如果为 0,则会销毁 graph。然后销毁 session 和 TF_Session 对象。通过 Session 实现类的析构函数,来销毁 session,释放线程池 Executor,资源管理器 ResourceManager 等资源。
DirectSession::~DirectSession() {
for (auto& it : partial_runs_) {
it.second.reset(nullptr);
}
// 释放线程池 Executor
for (auto& it : executors_) {
it.second.reset();
}
for (auto d : device_mgr_->ListDevices()) {
d->op_segment()->RemoveHold(session_handle_);
}
// 释放 ResourceManager
for (auto d : device_mgr_->ListDevices()) {
d->ClearResourceMgr();
}
// 释放 CancellationManager 实例
functions_.clear();
delete cancellation_manager_;
// 释放 ThreadPool
for (const auto& p_and_owned : thread_pools_) {
if (p_and_owned.second) delete p_and_owned.first;
}
execution_state_.reset(nullptr);
flib_def_.reset(nullptr);
}
6 总结
Session 是 TensorFlow 的 client 和 master 连接的桥梁,client 任何运算也是通过 session 来 run。它是 client 端最重要的对象。在 Python 层和 C ++ 层,均有不同的 session 实现。session 生命周期会经历四个阶段,create run close 和 del。四个阶段均由 Python 前端开始,最终调用到 C 层后端实现。由此也可以看到,TensorFlow 框架的前后端分离和模块化设计是多么的精巧。
本文作者:扬易阅读原文
本文为云栖社区原创内容,未经允许不得转载。